---
title: "Reduced Wong-Wang BOLD FC Optimization"
subtitle: "Fitting Functional Connectivity Using Network Dynamics and BOLD Simulation"
format:
  html:
    code-fold: false
    toc: true
    echo: false
    embed-resources: true
    fig-width: 8
    out-width: "100%"
jupyter: python3
execute:
  cache: true
---
Try this notebook interactively:
[Download .ipynb](https://github.com/virtual-twin/tvboptim/blob/main/docs/workflows/RWW.ipynb){.btn .btn-primary download="RWW.ipynb"}
[Download .qmd](RWW.qmd){.btn .btn-secondary download="RWW.qmd"}
[Open in Colab](https://colab.research.google.com/github/virtual-twin/tvboptim/blob/main/docs/workflows/RWW.ipynb){.btn .btn-warning target="_blank"}
## Introduction
This tutorial demonstrates how to use TVB-Optim to fit functional connectivity (FC) data from resting-state fMRI. We use the Reduced Wong-Wang (RWW) neural mass model to simulate brain activity, convert it to BOLD signal using a hemodynamic response function, and optimize model parameters to match empirical FC patterns.
```{python}
#| output: false
#| echo: false
# Install dependencies if running in Google Colab
try:
    import google.colab
    print("Running in Google Colab - installing dependencies...")
    !pip install -q tvboptim
    print("✓ Dependencies installed!")
except ImportError:
    pass  # Not in Colab, assume dependencies are available
```
The workflow includes:
- Building a whole-brain network with the RWW model
- Simulating BOLD signal from neural activity
- Computing functional connectivity from BOLD
- Optimizing global and region-specific parameters to fit target FC
```{python}
#| output: false
#| code-fold: true
#| code-summary: "Environment Setup and Imports"
#| echo: true
# Set up environment
import os
import time
cpu = True
if cpu:
    N = 8
    os.environ['XLA_FLAGS'] = f'--xla_force_host_platform_device_count={N}'
# Import all required libraries
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patheffects as path_effects
import jax
import jax.numpy as jnp
import copy
import optax
from scipy import io
# Jax enable x64
jax.config.update("jax_enable_x64", True)
# Import from tvboptim
from tvboptim.types import Parameter, Space, GridAxis
from tvboptim.types.stateutils import show_parameters
from tvboptim.utils import set_cache_path, cache
from tvboptim.execution import ParallelExecution, SequentialExecution
from tvboptim.optim.optax import OptaxOptimizer
from tvboptim.optim.callbacks import MultiCallback, DefaultPrintCallback, SavingCallback
# Network dynamics imports
from tvboptim.experimental.network_dynamics import Network, solve, prepare
from tvboptim.experimental.network_dynamics.dynamics.tvb import ReducedWongWang
from tvboptim.experimental.network_dynamics.coupling import LinearCoupling, FastLinearCoupling
from tvboptim.experimental.network_dynamics.graph import DenseDelayGraph, DenseGraph
from tvboptim.experimental.network_dynamics.solvers import Heun
from tvboptim.experimental.network_dynamics.noise import AdditiveNoise
from tvboptim.data import load_structural_connectivity, load_functional_connectivity
# BOLD monitoring
from tvboptim.observations.tvb_monitors.bold import Bold
# Observation functions
from tvboptim.observations.observation import compute_fc, fc_corr, rmse
# Set cache path for tvboptim
set_cache_path("./rww")
```
We enable 64-bit floating-point precision (already applied in the setup block above) so that gradients computed through the long simulation remain numerically reliable.
```{python}
#| output: false
#| echo: true
jax.config.update("jax_enable_x64", True)
```
## Loading Structural Data and Target FC
We load the Desikan-Killiany parcellation structural connectivity and empirical functional connectivity from resting-state fMRI data.
```{python}
#| output: false
#| echo: true
#| code-fold: true
#| code-summary: "Load structural connectivity and target FC"
# Load structural connectivity with region labels
# No delays for this model (instantaneous coupling)
weights, lengths, region_labels = load_structural_connectivity(name="dk_average")
# Normalize weights to [0, 1] range
weights = weights / np.max(weights)
n_nodes = weights.shape[0]
# Load empirical functional connectivity as optimization target
fc_target = load_functional_connectivity(name="dk_average")
```
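As a quick, optional sanity check (a sketch, not executed here), the three matrices should share one square shape and the label list should match the number of regions; this assumes `region_labels` supports `len`:
```python
# Optional sanity check (illustrative; not executed here)
assert weights.shape == lengths.shape == fc_target.shape == (n_nodes, n_nodes)
assert len(region_labels) == n_nodes
print(f"{n_nodes} regions, max weight after normalization: {weights.max():.2f}")
```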
```{python}
#| label: fig-connectivity
#| fig-cap: "**Structural connectivity matrices.** Left: Normalized connection weights showing the strength of white matter connections between brain regions. Right: Tract lengths in millimeters representing the physical distance of fiber pathways."
#| code-fold: true
#| code-summary: "Show plotting code"
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8.1, 4.05), sharey=True)
im1 = ax1.imshow(weights, cmap="cividis", vmax=0.5)
ax1.set_title("Structural Connectivity")
ax1.set_xlabel("Region")
ax1.set_ylabel("Region")
cbar1 = fig.colorbar(im1, ax=ax1, shrink=0.74, label="Connection Strength [a.u.]", extend='max')
im2 = ax2.imshow(lengths, cmap="cividis")
ax2.set_title("Tract Lengths")
ax2.set_xlabel("Region")
cbar2 = fig.colorbar(im2, ax=ax2, shrink=0.74, label="Tract Length [mm]")
plt.tight_layout()
```
## The Reduced Wong-Wang Model
The Reduced Wong-Wang model is a biophysically based neural mass model that describes the dynamics of NMDA-mediated synaptic gating. It captures the slow dynamics relevant to resting-state fMRI and has been widely used to model whole-brain functional connectivity.
The model describes the evolution of synaptic gating variable S:
$$\frac{dS}{dt} = -\frac{S}{\tau_s} + (1-S) \cdot H(x) \cdot \gamma$$
where $x = w \cdot J_N \cdot S + I_o + G \cdot c$ combines local recurrence ($w$, scaled by the synaptic coupling $J_N$), external input ($I_o$), and long-range coupling ($G \cdot c$), and $H(x)$ is the nonlinear firing-rate transfer function; a minimal code sketch follows the parameter list below.
Key parameters:
- `w`: Excitatory recurrence strength (local feedback)
- `I_o`: External input current
- `G` (coupling strength): Global scaling of long-range connections
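To make the equations concrete, here is a minimal, self-contained sketch of the right-hand side in plain JAX. The transfer-function constants (`a`, `b`, `d`) and the kinetic constants (`J_N`, `tau_s`, `gamma`) are the commonly cited Wong-Wang defaults (time in ms) and are not necessarily identical to tvboptim's internal `ReducedWongWang` implementation:
```python
# Illustrative sketch of the RWW right-hand side (not tvboptim's implementation)
import jax.numpy as jnp

def H(x, a=270.0, b=108.0, d=0.154):
    """Firing-rate transfer function H(x)."""
    u = a * x - b
    return u / (1.0 - jnp.exp(-d * u))

def dS_dt(S, c, w=0.5, I_o=0.32, G=0.5, J_N=0.2609, tau_s=100.0, gamma=0.000641):
    """dS/dt for one region, given the long-range coupling input c."""
    x = w * J_N * S + I_o + G * c
    return -S / tau_s + (1.0 - S) * H(x) * gamma
```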
## Building the Network Model
We combine the RWW dynamics with structural connectivity to create a whole-brain network model.
```{python}
#| echo: true
#| output: false
# Create network components
graph = DenseGraph(weights, region_labels=region_labels)
dynamics = ReducedWongWang(w=0.5, I_o=0.32, INITIAL_STATE=(0.3,))
coupling = FastLinearCoupling(local_states=["S"], G=0.5)
noise = AdditiveNoise(sigma=0.00283, apply_to="S")
# Assemble the network
network = Network(
    dynamics=dynamics,
    coupling={'instant': coupling},
    graph=graph,
    noise=noise
)
```
## Preparing and Running the Simulation
We prepare the network for simulation and run an initial transient to reach a quasi-stationary state.
```{python}
#| echo: true
# Prepare simulation: compile model and initialize state
t1 = 120_000 # Total simulation duration (ms) - 2 minutes
dt = 4.0 # Integration timestep (ms)
model, state = prepare(network, Heun(), t1=t1, dt=dt)
# First simulation: run transient to reach quasi-stationary state
result_init = model(state)
# Update network with final state as new initial conditions
network.update_history(result_init)
model, state = prepare(network, Heun(), t1=t1, dt=dt)
# Second simulation: quasi-stationary dynamics
result = model(state)
```
## Computing BOLD Signal
We convert the neural activity (synaptic gating S) to simulated BOLD signal using a hemodynamic response function. The BOLD monitor downsamples the neural activity and convolves it with a canonical HRF kernel.
```{python}
#| echo: true
# Create BOLD monitor with standard parameters
bold_monitor = Bold(
    period=1000.0,           # BOLD sampling period (1 TR = 1000 ms)
    downsample_period=4.0,   # Intermediate downsampling matches dt
    voi=0,                   # Monitor first state variable (S)
    history=result_init      # Use initial state as warm start
)
# Apply BOLD monitor to simulation result
bold_result = bold_monitor(result)
```
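Conceptually, this amounts to convolving the downsampled neural signal with a hemodynamic kernel and sampling the result at the TR. The sketch below uses a canonical double-gamma HRF purely for intuition; the actual `Bold` monitor uses its own kernel and buffering, so treat this as illustrative (not executed here):
```python
# Conceptual BOLD convolution (illustrative only; not the Bold monitor's code)
import numpy as np
from scipy.stats import gamma

def double_gamma_hrf(t, peak=6.0, undershoot=16.0, ratio=1.0 / 6.0):
    """Canonical double-gamma HRF evaluated at times t (seconds)."""
    return gamma.pdf(t, peak) - ratio * gamma.pdf(t, undershoot)

tr = 1.0                                  # sampling period in seconds
kernel = double_gamma_hrf(np.arange(0.0, 32.0, tr))
neural = np.random.rand(300)              # stand-in for a downsampled S time series
bold = np.convolve(neural, kernel)[: len(neural)]
```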
```{python}
#| label: fig-timeseries
#| fig-cap: "**Neural activity and BOLD signal time series.** Left: Raw synaptic gating variable (S) showing fast neural dynamics over 1 second. Right: Simulated BOLD signal showing slow hemodynamic response over 60 seconds. Each line represents one brain region, colored by mean activity level."
#| code-fold: true
#| code-summary: "Show plotting code"
from matplotlib.colors import Normalize
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8.1, 3.0375))
# Plot raw neural activity (first 1000 ms)
t_max_idx = int(1000 / dt)
time_raw = result.time[:t_max_idx]
data_raw = result.data[:t_max_idx, 0, :]
num_lines = data_raw.shape[1]
cmap = plt.cm.cividis
mean_values = np.mean(data_raw, axis=0)
norm = Normalize(vmin=np.min(mean_values), vmax=np.max(mean_values))
for i in range(num_lines):
    color = cmap(norm(mean_values[i]))
    ax1.plot(time_raw, data_raw[:, i], color=color, linewidth=0.5)
ax1.text(0.95, 0.95, "Raw Neural Activity", transform=ax1.transAxes, fontsize=10,
         ha='right', va='top', bbox=dict(boxstyle="round,pad=0.3", facecolor='white', alpha=0.8))
ax1.set_xlabel("Time [ms]")
ax1.set_ylabel("S [a.u.]")
# Plot BOLD signal (first 60 TRs)
t_bold_max = 60
time_bold = bold_result.time[:t_bold_max]
data_bold = bold_result.data[:t_bold_max, 0, :]
num_lines = data_bold.shape[1]
mean_values = np.mean(data_bold, axis=0)
norm = Normalize(vmin=np.min(mean_values), vmax=np.max(mean_values))
for i in range(num_lines):
    color = cmap(norm(mean_values[i]))
    ax2.plot(time_bold, data_bold[:, i], color=color, linewidth=0.8)
ax2.text(0.95, 0.95, "BOLD Signal", transform=ax2.transAxes, fontsize=10,
         ha='right', va='top', bbox=dict(boxstyle="round,pad=0.3", facecolor='white', alpha=0.8))
ax2.set_xlabel("Time [s]")
ax2.set_ylabel("BOLD [a.u.]")
plt.tight_layout()
```
## Defining Observations and Loss
Functional connectivity (FC) measures the temporal correlation between BOLD signals from different brain regions. We define an observation function that simulates BOLD and computes FC, and a loss function that quantifies the mismatch with empirical FC.
```{python}
#| echo: true
def observation(state):
    """Compute functional connectivity from simulated BOLD signal."""
    # Run simulation
    result = model(state)
    # Convert to BOLD
    bold = bold_monitor(result)
    # Compute FC, skipping first 20 TRs to avoid transient effects
    fc = compute_fc(bold, skip_t=20)
    return fc

def loss(state):
    """Compute RMSE between simulated and empirical FC."""
    fc = observation(state)
    return rmse(fc, fc_target)
```
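Under the hood, FC is simply the Pearson correlation matrix of the regional BOLD time series, and the two metrics compare matrices element-wise. A rough NumPy equivalent of `compute_fc`, `fc_corr`, and `rmse` (illustrative; the tvboptim versions may differ in details such as input shapes and which entries they include):
```python
# Rough NumPy equivalents of the observation helpers (illustrative only)
import numpy as np

def compute_fc_np(bold_data, skip_t=20):
    """Pearson correlation between regions; bold_data has shape (time, regions)."""
    return np.corrcoef(bold_data[skip_t:].T)

def fc_corr_np(fc_a, fc_b):
    """Correlation between the upper triangles of two FC matrices."""
    idx = np.triu_indices_from(fc_a, k=1)
    return np.corrcoef(fc_a[idx], fc_b[idx])[0, 1]

def rmse_np(fc_a, fc_b):
    """Root-mean-square error between two FC matrices."""
    return np.sqrt(np.mean((fc_a - fc_b) ** 2))
```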
```{python}
#| label: fig-initial-fc
#| fig-cap: "**Initial functional connectivity comparison.** Left: Empirical FC from resting-state fMRI serving as optimization target. Right: Simulated FC from initial model parameters showing poor correlation with target (r = correlation coefficient between the two matrices)."
#| code-fold: true
#| code-summary: "Show plotting code"
# Calculate initial FC
fc_initial = np.array(observation(state))
# Create figure
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8.1, 3.54375))
# Plot both FC matrices
for ax_current, fc_matrix, title_prefix in zip([ax1, ax2], [fc_target, fc_initial], ["Target FC", "Initial FC"]):
    fc_matrix = np.copy(fc_matrix)
    np.fill_diagonal(fc_matrix, np.nan)  # Set diagonal to NaN
    im = ax_current.imshow(fc_matrix, cmap='cividis', vmax=0.9)
    ax_current.set_xticks([])
    ax_current.set_yticks([])
    ax_current.set_xlabel('')
    ax_current.set_ylabel('')
    # Calculate correlation for title
    if title_prefix == "Initial FC":
        corr_value = fc_corr(fc_initial, fc_target)
        title = f"{title_prefix}\nr = {corr_value:.3f}"
    else:
        title = title_prefix
    # Add title as annotation
    ax_current.annotate(title,
                        xy=(0.25, 0.95),
                        xycoords='axes fraction',
                        ha='left', va='top',
                        fontsize=10, fontweight='bold',
                        color='black',
                        bbox=dict(boxstyle='round,pad=0.3',
                                  facecolor='white', alpha=0.9))
plt.tight_layout()
```
## Parameter Exploration
Before optimization, we explore how the model parameters affect FC quality. We systematically vary the excitatory recurrence `w` and global coupling strength `G` across a 2D grid and compute the loss for each combination.
```{python}
#| echo: true
#| output: false
# Create grid for parameter exploration
n = 32
# Set up parameter axes for exploration
grid_state = copy.deepcopy(state)
grid_state.dynamics.w = GridAxis(0.001, 0.7, n)
grid_state.coupling.instant.G = GridAxis(0.001, 0.7, n)
# Create space (product creates all combinations of w and G)
grid = Space(grid_state, mode="product")
@cache("explore", redo=False)
def explore():
# Parallel execution across 8 processes
exec = ParallelExecution(loss, grid, n_pmap=8)
# Alternative: Sequential execution
# exec = SequentialExecution(loss, grid)
return exec.run()
exploration_results = explore()
```
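It can also be handy to read off the best grid point directly; a small sketch (not executed here), assuming `exploration_results` is ordered consistently with the flattened values returned by `grid.collect()`:
```python
# Locate the best (G, w) on the explored grid (assumes consistent ordering)
pc = grid.collect()
losses = np.asarray(jnp.stack(exploration_results)).flatten()
best = int(np.argmin(losses))
G_best = float(pc.coupling.instant.G.flatten()[best])
w_best = float(pc.dynamics.w.flatten()[best])
print(f"Best grid point: G = {G_best:.3f}, w = {w_best:.3f}, loss = {losses[best]:.4f}")
```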
```{python}
#| label: fig-exploration
#| fig-cap: "**Parameter landscape exploration.** The heatmap shows FC fitting loss (RMSE) across the parameter space of excitatory recurrence (w) and global coupling (G). Dark regions indicate better FC fits. The landscape reveals an optimal region where both parameters balance to reproduce empirical connectivity patterns."
#| code-fold: true
#| code-summary: "Show plotting code"
# Prepare data for visualization
pc = grid.collect()
G_vals = pc.coupling.instant.G.flatten()
w_vals = pc.dynamics.w.flatten()
# Get parameter ranges
G_min, G_max = min(G_vals), max(G_vals)
w_min, w_max = min(w_vals), max(w_vals)
# Create figure and axis
fig, ax = plt.subplots(figsize=(8.1, 4.05))
# Create the heatmap
im = ax.imshow(jnp.stack(exploration_results).reshape(n, n).T,
cmap='cividis_r',
extent=[G_min, G_max, w_min, w_max],
origin='lower',
aspect='auto',
interpolation='none')
# Add colorbar and labels
cbar = plt.colorbar(im, label="Loss (RMSE)")
ax.set_xlabel('Global Coupling (G)')
ax.set_ylabel('Excitatory Recurrence (w)')
ax.set_title("Parameter Exploration")
plt.tight_layout()
```
## Gradient-Based Optimization
We use gradient-based optimization to find the best global parameters (same values for all regions) that minimize the FC mismatch. JAX's automatic differentiation computes gradients through the entire simulation pipeline.
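Conceptually, the optimizer repeats a standard Optax update loop: evaluate the loss, backpropagate through the simulation, and apply Adam updates to the marked parameters. A stripped-down sketch of that loop over a generic pytree of parameters (not tvboptim's `OptaxOptimizer` itself, which additionally handles `Parameter` bookkeeping and callbacks):
```python
# Generic gradient-descent loop with Optax (conceptual sketch)
import jax
import optax

def fit(loss_fn, params, steps=300, learning_rate=0.01):
    opt = optax.adam(learning_rate)
    opt_state = opt.init(params)
    for _ in range(steps):
        grads = jax.grad(loss_fn)(params)          # differentiate through the pipeline
        updates, opt_state = opt.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
    return params
```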
```{python}
#| echo: true
#| output: false
# Mark parameters as optimizable
state.coupling.instant.G = Parameter(state.coupling.instant.G)
state.dynamics.w = Parameter(state.dynamics.w)
# Create and run optimizer
cb = MultiCallback([
DefaultPrintCallback(every=10),
SavingCallback(key="state", save_fun=lambda *args: args[1]) # Save updated state
])
@cache("optimize", redo=False)
def optimize():
opt = OptaxOptimizer(loss, optax.adam(0.01, b2=0.9999), callback=cb)
fitted_state, fitting_data = opt.run(state, max_steps=300)
return fitted_state, fitting_data
fitted_state, fitting_data = optimize()
```
```{python}
#| label: fig-optimization-trajectory
#| fig-cap: "**Optimization trajectory in parameter space.** White points show the path taken by gradient descent from initial parameters (top marker) to optimized values (bottom marker). The optimizer efficiently navigates the loss landscape to find parameter combinations that yield good FC fits."
#| code-fold: true
#| code-summary: "Show plotting code"
# Prepare data for visualization
pc = grid.collect()
G_vals = pc.coupling.instant.G.flatten()
w_vals = pc.dynamics.w.flatten()
# Get parameter ranges
G_min, G_max = min(G_vals), max(G_vals)
w_min, w_max = min(w_vals), max(w_vals)
# Create figure and axis
fig, ax = plt.subplots(figsize=(8.1, 4.725))
# Create the heatmap
im = ax.imshow(jnp.stack(exploration_results).reshape(n, n).T,
cmap='cividis_r',
extent=[G_min, G_max, w_min, w_max],
origin='lower',
aspect='auto',
interpolation='none')
# Mark initial value
G_init = state.coupling.instant.G.value
w_init = state.dynamics.w.value
ax.scatter(G_init, w_init, color='white', s=100, marker='o',
edgecolors='k', linewidths=2, zorder=5)
# Add annotation
ax.annotate('Initial', xy=(G_init, w_init),
xytext=(G_init, w_init+0.05*(w_max-w_min)),
color='white', fontweight='bold', ha='center', zorder=5,
path_effects=[path_effects.withStroke(linewidth=3, foreground='black')])
# Add fitted value point
G_fit = fitted_state.coupling.instant.G.value
w_fit = fitted_state.dynamics.w.value
ax.scatter(G_fit, w_fit, color='white', s=100, marker='o',
edgecolors='k', linewidths=2, zorder=5)
# Add annotation for the fitted value
ax.annotate('Optimized', xy=(G_fit, w_fit),
xytext=(G_fit, w_fit-0.08*(w_max-w_min)),
color='white', fontweight='bold', ha='center', zorder=5,
path_effects=[path_effects.withStroke(linewidth=3, foreground='black')])
# Add optimization path points
G_route = np.array([ds.coupling.instant.G.value for ds in fitting_data["state"].save])
w_route = np.array([ds.dynamics.w.value for ds in fitting_data["state"].save])
ax.scatter(G_route[::2], w_route[::2], color='white', s=15, marker='o',
linewidths=1, zorder=4, edgecolors='k')
# Remove axes ticks and labels
ax.set_xticks([])
ax.set_yticks([])
ax.set_xlabel('')
ax.set_ylabel('')
plt.tight_layout()
```
## Heterogeneous Optimization
Global parameters (identical across regions) may not capture the region-specific variation needed for an optimal FC fit. We now make the parameters heterogeneous: each brain region gets its own `w` and `I_o` value, while `G` stays global and is fixed at its previously optimized value.
```{python}
#| echo: true
# Copy already optimized state and make parameters regional
fitted_state_het = copy.deepcopy(fitted_state)
# Make w regional (one value per node)
fitted_state_het.dynamics.w.shape = (n_nodes,)
# Also make I_o regional and mark as optimizable
fitted_state_het.dynamics.I_o = Parameter(fitted_state_het.dynamics.I_o)
fitted_state_het.dynamics.I_o.shape = (n_nodes,)
# Keep global coupling fixed at optimized value
fitted_state_het.coupling.instant.G = fitted_state_het.coupling.instant.G.value
show_parameters(fitted_state_het)
```
```{python}
#| echo: true
#| output: false
@cache("optimize_het", redo=False)
def optimize_het():
opt = OptaxOptimizer(loss, optax.adam(0.004, b2=0.999), callback=cb)
fitted_state, fitting_data = opt.run(fitted_state_het, max_steps=200)
return fitted_state, fitting_data
fitted_state_het, fitting_data_het = optimize_het()
```
## Comparing Global vs Regional Parameters
Let's compare the FC quality from global (homogeneous) vs regional (heterogeneous) parameter fits.
```{python}
#| output: false
#| echo: true
# Compute FC for both optimization strategies
fc_global = np.array(observation(fitted_state))
fc_regional = np.array(observation(fitted_state_het))
```
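A quick numeric comparison using the metric helpers imported earlier (a sketch, not executed here; the exact values depend on the optimization runs above):
```python
# Compare both fits against the empirical target (illustrative)
for name, fc in [("Global", fc_global), ("Regional", fc_regional)]:
    r = float(fc_corr(fc, fc_target))
    err = float(rmse(fc, fc_target))
    print(f"{name} fit: r = {r:.3f}, RMSE = {err:.4f}")
```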
```{python}
#| label: fig-fc-comparison
#| fig-cap: "**Comparison of functional connectivity matrices.** Left: Empirical target FC from resting-state fMRI. Middle: FC from global parameter optimization. Right: FC from regional parameter optimization. The correlation coefficient (r) quantifies the similarity to the target FC. Regional parameters achieve better fit by accounting for local variations."
#| code-fold: true
#| code-summary: "Show plotting code"
# Create the figure with three subplots
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(8.1, 3.54375))
# Plot the FC matrices
for ax_current, fc_matrix, title_prefix in zip([ax1, ax2, ax3], [fc_target, fc_global, fc_regional], ["Target FC", "Global Parameters", "Regional Parameters"]):
    fc_matrix = np.copy(fc_matrix)
    np.fill_diagonal(fc_matrix, np.nan)  # Set diagonal to NaN
    im = ax_current.imshow(fc_matrix, cmap='cividis', vmax=1.0)
    ax_current.set_xticks([])
    ax_current.set_yticks([])
    ax_current.set_xlabel('')
    ax_current.set_ylabel('')
    # Calculate correlation for title (if not target)
    if title_prefix == "Target FC":
        title = title_prefix
    elif title_prefix == "Global Parameters":
        corr_value = fc_corr(fc_global, fc_target)
        title = f"{title_prefix}\nr = {corr_value:.3f}"
    else:
        corr_value = fc_corr(fc_regional, fc_target)
        title = f"{title_prefix}\nr = {corr_value:.3f}"
    # Set title
    ax_current.set_title(title, fontsize=10, fontweight='bold')
plt.tight_layout()
```
```{python}
#| label: fig-fc-scatter
#| fig-cap: "**Scatter plots comparing fitted vs empirical FC.** Each point represents one pairwise connection between brain regions. Left: Global parameter fit shows good overall correlation. Right: Regional parameter fit shows improved correlation with reduced scatter, indicating better reproduction of the empirical FC structure. The diagonal line represents perfect agreement."
#| code-fold: true
#| code-summary: "Show plotting code"
# Create figure with two scatter plots
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8.1, 5.4), sharey=True, sharex=True)
# Get upper triangular indices (exclude diagonal)
triu_idx = np.triu_indices_from(fc_target, k=1)
# Extract upper triangular values
fc_target_triu = fc_target[triu_idx]
fc_global_triu = fc_global[triu_idx]
fc_regional_triu = fc_regional[triu_idx]
# Plot global parameters
ax1.scatter(fc_target_triu, fc_global_triu, alpha=0.3, s=10, color='royalblue', edgecolors='none')
ax1.plot([fc_target_triu.min(), fc_target_triu.max()],
[fc_target_triu.min(), fc_target_triu.max()],
'k--', linewidth=1.5, label='Perfect fit')
corr_global = fc_corr(fc_global, fc_target)
ax1.set_xlabel('Empirical FC')
ax1.set_ylabel('Simulated FC')
ax1.set_title(f'Global Parameters\nr = {corr_global:.3f}')
ax1.grid(True, alpha=0.3)
ax1.set_aspect('equal', adjustable='box')
# Plot regional parameters
ax2.scatter(fc_target_triu, fc_regional_triu, alpha=0.3, s=10, color='royalblue', edgecolors='none')
ax2.plot([fc_target_triu.min(), fc_target_triu.max()],
[fc_target_triu.min(), fc_target_triu.max()],
'k--', linewidth=1.5, label='Perfect fit')
corr_regional = fc_corr(fc_regional, fc_target)
ax2.set_xlabel('Empirical FC')
ax2.set_ylabel('Simulated FC')
ax2.set_title(f'Regional Parameters\nr = {corr_regional:.3f}')
ax2.grid(True, alpha=0.3)
ax2.set_aspect('equal', adjustable='box')
plt.tight_layout()
```
## Fitted Heterogeneous Parameters
Let's examine the fitted region-specific parameters and their relationship to structural connectivity.
```{python}
#| label: fig-fitted-params
#| fig-cap: "**Fitted heterogeneous parameters.** Left: Fitted excitatory recurrence (w) for each region plotted against mean incoming structural connectivity strength. Right: Fitted external input (I_o) vs mean connectivity. Dashed lines show the global optimization values for reference. Regions with stronger structural connections tend to require different parameter values to achieve optimal FC fit, demonstrating the importance of region-specific tuning."
#| code-fold: true
#| code-summary: "Show plotting code"
# Calculate mean incoming connectivity for each region
mean_connectivity = np.mean(weights, axis=1)
# Extract fitted regional parameters
w_fitted = fitted_state_het.dynamics.w.value.flatten()
I_o_fitted = fitted_state_het.dynamics.I_o.value.flatten()
# Get global optimization values for reference
w_global = fitted_state.dynamics.w.value
I_o_global = fitted_state.dynamics.I_o  # I_o was not optimized in the global fit, so this is its initial value
# Create figure
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8.1, 3.24))
# Plot w vs mean connectivity
ax1.scatter(mean_connectivity, w_fitted, alpha=0.7, s=30, color='royalblue', edgecolors='k', linewidths=0.5)
ax1.axhline(w_global, color='red', linestyle='--', linewidth=2, label=f'Global w = {w_global:.3f}')
ax1.set_xlabel('Mean Incoming Connectivity')
ax1.set_ylabel('Fitted w (Excitatory Recurrence)')
ax1.set_title('Regional Excitatory Recurrence Parameters')
ax1.legend(loc='best')
ax1.grid(True, alpha=0.3)
# Plot I_o vs mean connectivity
ax2.scatter(mean_connectivity, I_o_fitted, alpha=0.7, s=30, color='royalblue', edgecolors='k', linewidths=0.5)
ax2.axhline(I_o_global, color='red', linestyle='--', linewidth=2, label=f'Initial I_o = {I_o_global:.3f}')
ax2.set_xlabel('Mean Incoming Connectivity')
ax2.set_ylabel('Fitted I_o (External Input)')
ax2.set_title('Regional External Input Parameters')
ax2.legend(loc='best')
ax2.grid(True, alpha=0.3)
plt.tight_layout()
```
::: {.callout-note}
## Parameter Constraints
Notice that some regions have negative `w` values, which changes the biological interpretation of the parameter from excitatory to inhibitory recurrence. While this may be mathematically valid for achieving good FC fits, it violates the intended model constraints where `w` represents excitatory feedback strength.
In a further refinement step, we could use `BoundedParameter` to enforce physiological constraints during optimization:
```python
# Example of using bounded parameters (not executed here)
from tvboptim.types import BoundedParameter
# Constrain w to positive values only
state.dynamics.w = BoundedParameter(
    state.dynamics.w,
    lower_bound=0.0,  # Enforce excitatory nature
    upper_bound=1.0   # Maximum recurrence strength
)
```
This would ensure that the optimizer only explores biologically plausible parameter regions while still allowing heterogeneous regional variations.
:::
## Summary
This tutorial demonstrated the complete workflow for fitting brain network models to functional connectivity data using TVB-Optim:
1. **Model Construction**: We built a whole-brain network using the Reduced Wong-Wang neural mass model with structural connectivity.
2. **BOLD Simulation**: We converted neural activity to simulated BOLD signal using a hemodynamic response function, matching the temporal scale of fMRI data.
3. **FC Computation**: We computed functional connectivity from simulated BOLD and defined a loss function measuring mismatch with empirical FC.
4. **Parameter Exploration**: We systematically explored the parameter space to understand the relationship between model parameters and FC quality.
5. **Gradient-Based Optimization**: We used automatic differentiation through the entire simulation pipeline to optimize global parameters.
6. **Heterogeneous Parameters**: We refined the model with region-specific parameters, achieving better FC fits by accounting for regional variations in neural dynamics.
This approach showcases TVB-Optim's capability to perform end-to-end optimization of complex, biophysically-realistic brain network models with automatic differentiation through stochastic simulations and signal processing pipelines.