Reduced Wong-Wang BOLD FC Optimization

Fitting Functional Connectivity Using Network Dynamics and BOLD Simulation

Introduction

This tutorial demonstrates how to use TVB-Optim to fit functional connectivity (FC) data from resting-state fMRI. We use the Reduced Wong-Wang (RWW) neural mass model to simulate brain activity, convert it to BOLD signal using a hemodynamic response function, and optimize model parameters to match empirical FC patterns.

The workflow includes:

  • Building a whole-brain network with the RWW model
  • Simulating BOLD signal from neural activity
  • Computing functional connectivity from BOLD
  • Optimizing global and region-specific parameters to fit target FC

Environment Setup and Imports
# Set up environment
import os
import time
cpu = True
if cpu:
    N = 8
    os.environ['XLA_FLAGS'] = f'--xla_force_host_platform_device_count={N}'

# Import all required libraries
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patheffects as path_effects
import jax
import jax.numpy as jnp
import copy
import optax
from scipy import io

# Enable 64-bit floating point in JAX
jax.config.update("jax_enable_x64", True)

# Import from tvboptim
from tvboptim.types import Parameter, Space, GridAxis
from tvboptim.types.stateutils import show_parameters
from tvboptim.utils import set_cache_path, cache
from tvboptim.execution import ParallelExecution, SequentialExecution
from tvboptim.optim.optax import OptaxOptimizer
from tvboptim.optim.callbacks import MultiCallback, DefaultPrintCallback, SavingCallback

# Network dynamics imports
from tvboptim.experimental.network_dynamics import Network, solve, prepare
from tvboptim.experimental.network_dynamics.dynamics.tvb import ReducedWongWang
from tvboptim.experimental.network_dynamics.coupling import LinearCoupling, FastLinearCoupling
from tvboptim.experimental.network_dynamics.graph import DenseDelayGraph, DenseGraph
from tvboptim.experimental.network_dynamics.solvers import Heun
from tvboptim.experimental.network_dynamics.noise import AdditiveNoise
from tvboptim.data import load_structural_connectivity, load_functional_connectivity

# BOLD monitoring
from tvboptim.observations.tvb_monitors.bold import Bold

# Observation functions
from tvboptim.observations.observation import compute_fc, fc_corr, rmse

# Set cache path for tvboptim
set_cache_path("./rww")

We enable 64-bit precision to get reliable gradient information.

jax.config.update("jax_enable_x64", True)

Loading Structural Data and Target FC

We load the Desikan-Killiany parcellation structural connectivity and empirical functional connectivity from resting-state fMRI data.

# Load structural connectivity with region labels
# No delays for this model (instantaneous coupling)
weights, lengths, region_labels = load_structural_connectivity(name="dk_average")

# Normalize weights to [0, 1] range
weights = weights / np.max(weights)
n_nodes = weights.shape[0]


# Load empirical functional connectivity as optimization target
fc_target = load_functional_connectivity(name="dk_average")
Figure 1: Structural connectivity matrices. Left: Normalized connection weights showing the strength of white matter connections between brain regions. Right: Tract lengths in millimeters representing the physical distance of fiber pathways.

The Reduced Wong-Wang Model

The Reduced Wong-Wang model is a biophysically-based neural mass model that describes the dynamics of NMDA-mediated synaptic gating. It captures the slow dynamics relevant for resting-state fMRI and has been widely used for modeling whole-brain functional connectivity.

The model describes the evolution of synaptic gating variable S:

\[\frac{dS}{dt} = -\frac{S}{\tau_s} + (1-S) \cdot H(x) \cdot \gamma\]

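where \(x = w \cdot J_N \cdot S + I_o + G \cdot c\) combines local recurrence (\(w\), scaled by the synaptic coupling constant \(J_N\)), external input (\(I_o\)), and long-range coupling (\(G \cdot c\), where \(c\) is the input a region receives from the rest of the network), and \(H(x)\) is a sigmoidal transfer function.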

Key parameters:

  • w: Excitatory recurrence strength (local feedback)
  • I_o: External input current
  • G (coupling strength): Global scaling of long-range connections
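The equation above can be written out in a few lines of NumPy. This is a minimal sketch, not the library's `ReducedWongWang` implementation: the constants `A`, `B`, `D`, `GAMMA`, `TAU_S`, and `J_N` are the values commonly quoted for this model (Deco et al., 2013) and are assumptions here, as is the specific form of \(H(x)\); the defaults in tvboptim may differ.

```python
import numpy as np

# Assumed constants (commonly quoted RWW values; library defaults may differ)
A, B, D = 0.270, 0.108, 154.0             # transfer-function parameters
GAMMA, TAU_S, J_N = 0.641, 100.0, 0.2609  # kinetic rate, time constant (ms), synaptic coupling

def transfer(x):
    """Assumed sigmoidal transfer function H(x) = u / (1 - exp(-d u)), u = a x - b.

    Note: this expression is 0/0 at u == 0 and would need special handling there.
    """
    u = A * x - B
    return u / (1.0 - np.exp(-D * u))

def rww_dS(S, c, w=0.5, I_o=0.32, G=0.5):
    """Right-hand side dS/dt for gating variable S and network input c."""
    x = w * J_N * S + I_o + G * c          # total input current
    return -S / TAU_S + (1.0 - S) * transfer(x) * GAMMA

S = np.full(3, 0.3)                        # three example regions at the initial state
dS = rww_dS(S, c=np.zeros(3))              # no network input in this toy call
```

With the tutorial's starting values (w = 0.5, I_o = 0.32, S = 0.3) and no network input, the decay term outweighs the drive, so S initially relaxes downward.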

Building the Network Model

We combine the RWW dynamics with structural connectivity to create a whole-brain network model.

# Create network components
graph = DenseGraph(weights, region_labels=region_labels)
dynamics = ReducedWongWang(w=0.5, I_o=0.32, INITIAL_STATE=(0.3,))
coupling = FastLinearCoupling(local_states=["S"], G=0.5)
noise = AdditiveNoise(sigma=0.00283, apply_to="S")

# Assemble the network
network = Network(
    dynamics=dynamics,
    coupling={'instant': coupling},
    graph=graph,
    noise=noise
)
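As a rough picture of what a linear coupling term contributes, the sketch below computes the scaled network input for a toy two-region system. The function name `linear_coupling` and the convention that rows of the weight matrix index the receiving region are assumptions for illustration; they are not taken from the tvboptim source.

```python
import numpy as np

def linear_coupling(weights, S, G):
    """Scaled network input G * sum_j weights[i, j] * S[j] for every region i."""
    return G * (weights @ S)

# Toy 2-region weights; rows are assumed to index the receiving region
W = np.array([[0.0, 1.0],
              [0.5, 0.0]])
S = np.array([0.2, 0.4])                   # gating variables of the two regions
coupling_input = linear_coupling(W, S, G=0.5)
```

Region 0 receives 0.5 * 1.0 * 0.4 = 0.2 and region 1 receives 0.5 * 0.5 * 0.2 = 0.05.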

Preparing and Running the Simulation

We prepare the network for simulation and run an initial transient to reach a quasi-stationary state.

# Prepare simulation: compile model and initialize state
t1 = 120_000  # Total simulation duration (ms) - 2 minutes
dt = 4.0      # Integration timestep (ms)
model, state = prepare(network, Heun(), t1=t1, dt=dt)

# First simulation: run transient to reach quasi-stationary state
result_init = model(state)

# Update network with final state as new initial conditions
network.update_history(result_init)
model, state = prepare(network, Heun(), t1=t1, dt=dt)

# Second simulation: quasi-stationary dynamics
result = model(state)
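The two-stage run above (transient, then restart from the final state) can be illustrated with a hand-written stochastic Heun integrator. This is a sketch under stated assumptions: the noise is scaled as `sigma * sqrt(dt)`, which may not match the scaling used by `AdditiveNoise`, and `heun_step` and `integrate` are hypothetical helper names.

```python
import numpy as np

def heun_step(f, x, dt, sigma=0.0, rng=None):
    """One stochastic Heun step; the same noise sample is used in both stages."""
    noise = 0.0
    if sigma > 0.0:
        noise = sigma * np.sqrt(dt) * rng.standard_normal(np.shape(x))
    predictor = x + dt * f(x) + noise                     # Euler predictor
    return x + 0.5 * dt * (f(x) + f(predictor)) + noise   # trapezoidal corrector

def integrate(f, x0, t1, dt, sigma=0.0, seed=0):
    """Integrate from x0 over [0, t1]; returns an array of shape (steps + 1, n)."""
    rng = np.random.default_rng(seed)
    x = np.asarray(x0, dtype=float)
    trajectory = [x]
    for _ in range(int(round(t1 / dt))):
        x = heun_step(f, x, dt, sigma, rng)
        trajectory.append(x)
    return np.stack(trajectory)

decay = lambda x: -x / 100.0               # toy dynamics with a 100 ms time constant

# First run: transient. Second run: restart from the final state of the first.
transient = integrate(decay, x0=[1.0], t1=100.0, dt=4.0)
stationary = integrate(decay, x0=transient[-1], t1=100.0, dt=4.0, sigma=0.01)
```

Without noise, the toy decay reaches about exp(-1) after 100 ms, which gives a quick check that the integrator behaves as expected at dt = 4 ms.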

Computing BOLD Signal

We convert the neural activity (synaptic gating S) to simulated BOLD signal using a hemodynamic response function. The BOLD monitor downsamples the neural activity and convolves it with a canonical HRF kernel.

# Create BOLD monitor with standard parameters
bold_monitor = Bold(
    period=1000.0,          # BOLD sampling period (1 TR = 1000 ms)
    downsample_period=4.0,  # Intermediate downsampling matches dt
    voi=0,                  # Monitor first state variable (S)
    history=result_init     # Use the transient run as history (warm start)
)

# Apply BOLD monitor to simulation result
bold_result = bold_monitor(result)
Figure 2: Neural activity and BOLD signal time series. Left: Raw synaptic gating variable (S) showing fast neural dynamics over 1 second. Right: Simulated BOLD signal showing slow hemodynamic response over 60 seconds. Each line represents one brain region, colored by mean activity level.
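The filter-then-subsample idea behind a BOLD monitor can be sketched with NumPy alone. The kernel below is a generic gamma-shaped curve peaking near 6 s, chosen only for illustration; it is not the kernel used by the `Bold` monitor, and `toy_hrf` and `toy_bold` are hypothetical names.

```python
import numpy as np

def toy_hrf(dt_ms, length_s=20.0, shape=5.0, scale_s=1.2):
    """Gamma-shaped kernel t**shape * exp(-t / scale), normalised to sum to 1."""
    n = int(round(length_s * 1000.0 / dt_ms))
    t_s = np.arange(n) * dt_ms / 1000.0
    kernel = t_s ** shape * np.exp(-t_s / scale_s)
    return kernel / kernel.sum()

def toy_bold(neural, dt_ms, period_ms):
    """Convolve each region (column) with the kernel, then keep one sample per TR."""
    kernel = toy_hrf(dt_ms)
    n_t = neural.shape[0]
    filtered = np.stack(
        [np.convolve(neural[:, i], kernel)[:n_t] for i in range(neural.shape[1])],
        axis=1,
    )
    step = int(round(period_ms / dt_ms))   # neural samples per TR
    return filtered[step - 1::step]

neural = np.ones((7500, 3))                # 30 s of constant activity at dt = 4 ms
bold = toy_bold(neural, dt_ms=4.0, period_ms=1000.0)
```

Thirty seconds of input at dt = 4 ms yields 30 BOLD samples per region. The early samples are lower than the later ones because the kernel has not yet seen 20 s of input, which is one reason the observation function later skips the first TRs.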

Defining Observations and Loss

Functional connectivity (FC) measures the temporal correlation between BOLD signals from different brain regions. We define an observation function that simulates BOLD and computes FC, and a loss function that quantifies the mismatch with empirical FC.

def observation(state):
    """Compute functional connectivity from simulated BOLD signal."""
    # Run simulation
    result = model(state)
    # Convert to BOLD
    bold = bold_monitor(result)
    # Compute FC, skipping first 20 TRs to avoid transient effects
    fc = compute_fc(bold, skip_t=20)
    return fc

def loss(state):
    """Compute RMSE between simulated and empirical FC."""
    fc = observation(state)
    return rmse(fc, fc_target)
Figure 3: Initial functional connectivity comparison. Left: Empirical FC from resting-state fMRI serving as optimization target. Right: Simulated FC from initial model parameters showing poor correlation with target (r = correlation coefficient between the two matrices).
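For readers who want to see the arithmetic behind these two functions, here is a plain NumPy analogue. It is a sketch, not the library code: `toy_fc` and `toy_rmse` are hypothetical names, and details such as whether `rmse` includes the diagonal in tvboptim are not confirmed here.

```python
import numpy as np

def toy_fc(bold, skip_t=0):
    """Pearson correlation between regions; bold has shape (time, regions)."""
    return np.corrcoef(bold[skip_t:].T)

def toy_rmse(a, b):
    """Root-mean-square error over all matrix entries."""
    return float(np.sqrt(np.mean((a - b) ** 2)))

rng = np.random.default_rng(0)
bold = rng.standard_normal((120, 4))       # synthetic BOLD: 120 TRs, 4 regions
bold[:, 1] += bold[:, 0]                   # make regions 0 and 1 correlated

fc = toy_fc(bold, skip_t=20)               # drop the first 20 TRs
loss_value = toy_rmse(fc, np.eye(4))       # mismatch against a toy target
```

The resulting matrix is symmetric with ones on the diagonal, and the loss is zero only when the simulated and target matrices agree entry by entry.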

Parameter Exploration

Before optimization, we explore how the model parameters affect FC quality. We systematically vary the excitatory recurrence w and global coupling strength G across a 2D grid and compute the loss for each combination.

# Create grid for parameter exploration
n = 32

# Set up parameter axes for exploration
grid_state = copy.deepcopy(state)
grid_state.dynamics.w = GridAxis(0.001, 0.7, n)
grid_state.coupling.instant.G = GridAxis(0.001, 0.7, n)

# Create space (product creates all combinations of w and G)
grid = Space(grid_state, mode="product")

@cache("explore", redo=False)
def explore():
    # Parallel execution across 8 devices (matches N set via XLA_FLAGS above)
    executor = ParallelExecution(loss, grid, n_pmap=8)
    # Alternative: sequential execution
    # executor = SequentialExecution(loss, grid)
    return executor.run()

exploration_results = explore()
Figure 4: Parameter landscape exploration. The heatmap shows FC fitting loss (RMSE) across the parameter space of excitatory recurrence (w) and global coupling (G). Dark regions indicate better FC fits. The landscape reveals an optimal region where both parameters balance to reproduce empirical connectivity patterns.
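A product grid of this kind is easy to reproduce with NumPy for intuition. The sketch assumes that `GridAxis(low, high, n)` means `n` evenly spaced values between the bounds, and it replaces the expensive simulation loss with a hypothetical quadratic `toy_loss`, so the numbers carry no scientific meaning.

```python
import numpy as np

def toy_loss(w, G):
    """Hypothetical stand-in for the FC loss with its minimum at w = 0.1, G = 0.4."""
    return (w - 0.1) ** 2 + (G - 0.4) ** 2

n = 32
w_axis = np.linspace(0.001, 0.7, n)        # assumed meaning of GridAxis(0.001, 0.7, n)
G_axis = np.linspace(0.001, 0.7, n)

# "product" mode: every w is combined with every G (n * n evaluations)
W_grid, G_grid = np.meshgrid(w_axis, G_axis, indexing="ij")
losses = toy_loss(W_grid, G_grid)

# Locate the best grid point
i, j = np.unravel_index(np.argmin(losses), losses.shape)
best_w, best_G = w_axis[i], G_axis[j]
```

The grid minimum can only be as precise as the grid spacing (about 0.023 here), which is why the tutorial follows the exploration with gradient-based refinement.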

Gradient-Based Optimization

We use gradient-based optimization to find the best global parameters (same values for all regions) that minimize the FC mismatch. JAX’s automatic differentiation computes gradients through the entire simulation pipeline.

# Mark parameters as optimizable
state.coupling.instant.G = Parameter(state.coupling.instant.G)
state.dynamics.w = Parameter(state.dynamics.w)

# Create and run optimizer
cb = MultiCallback([
    DefaultPrintCallback(every=10),
    SavingCallback(key="state", save_fun=lambda *args: args[1])  # Save updated state
])

@cache("optimize", redo=False)
def optimize():
    opt = OptaxOptimizer(loss, optax.adam(0.01, b2=0.9999), callback=cb)
    fitted_state, fitting_data = opt.run(state, max_steps=300)
    return fitted_state, fitting_data

fitted_state, fitting_data = optimize()
Figure 5: Optimization trajectory in parameter space. White points show the path taken by the Adam optimizer from the initial parameters (top marker) to the optimized values (bottom marker). The optimizer descends the loss landscape toward parameter combinations that yield good FC fits.
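To make the optimizer step concrete, the sketch below hand-writes the Adam update rule in NumPy and applies it to a toy quadratic with an analytic gradient. It is not optax and does not differentiate through a simulation; `adam`, `target`, and the quadratic loss are hypothetical stand-ins, and the hyperparameters are the usual Adam defaults rather than the values used in the tutorial.

```python
import numpy as np

def adam(grad_fn, params, lr=0.01, b1=0.9, b2=0.999, eps=1e-8, steps=300):
    """Minimise a function given its gradient; returns final params and the path."""
    p = np.asarray(params, dtype=float)
    m = np.zeros_like(p)                   # running mean of gradients
    v = np.zeros_like(p)                   # running mean of squared gradients
    path = [p.copy()]
    for t in range(1, steps + 1):
        g = grad_fn(p)
        m = b1 * m + (1 - b1) * g
        v = b2 * v + (1 - b2) * g ** 2
        m_hat = m / (1 - b1 ** t)          # bias correction
        v_hat = v / (1 - b2 ** t)
        p = p - lr * m_hat / (np.sqrt(v_hat) + eps)
        path.append(p.copy())
    return p, np.stack(path)

target = np.array([0.1, 0.4])              # hypothetical optimum for (w, G)
grad = lambda p: 2.0 * (p - target)        # gradient of sum((p - target) ** 2)
fitted, path = adam(grad, params=[0.5, 0.5])
```

The stored `path` plays the same role as the states collected by `SavingCallback`: it records where the parameters were after each step, which is what a trajectory plot needs.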

Heterogeneous Optimization

Global parameters (same for all regions) may not capture region-specific variations needed for optimal FC fit. We now make parameters heterogeneous: each brain region gets its own w and I_o values, while keeping G global.

# Copy already optimized state and make parameters regional
fitted_state_het = copy.deepcopy(fitted_state)

# Make w regional (one value per node)
fitted_state_het.dynamics.w.shape = (n_nodes,)

# Also make I_o regional and mark as optimizable
fitted_state_het.dynamics.I_o = Parameter(fitted_state_het.dynamics.I_o)
fitted_state_het.dynamics.I_o.shape = (n_nodes,)

# Keep global coupling fixed at optimized value
fitted_state_het.coupling.instant.G = fitted_state_het.coupling.instant.G.value

show_parameters(fitted_state_het)
Parameters
├── _internal: Bunch
├── coupling: Bunch
├── dynamics
│   ├── I_o
│   │   └── value: [0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32
 0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32
 0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32
 0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32
 0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32
 0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32 0.32]
│   └── w
│       └── value: [0.09014702 0.09014702 0.09014702 0.09014702 0.09014702 0.09014702
 0.09014702 0.09014702 0.09014702 0.09014702 0.09014702 0.09014702
 0.09014702 0.09014702 0.09014702 0.09014702 0.09014702 0.09014702
 0.09014702 0.09014702 0.09014702 0.09014702 0.09014702 0.09014702
 0.09014702 0.09014702 0.09014702 0.09014702 0.09014702 0.09014702
 0.09014702 0.09014702 0.09014702 0.09014702 0.09014702 0.09014702
 0.09014702 0.09014702 0.09014702 0.09014702 0.09014702 0.09014702
 0.09014702 0.09014702 0.09014702 0.09014702 0.09014702 0.09014702
 0.09014702 0.09014702 0.09014702 0.09014702 0.09014702 0.09014702
 0.09014702 0.09014702 0.09014702 0.09014702 0.09014702 0.09014702
 0.09014702 0.09014702 0.09014702 0.09014702 0.09014702 0.09014702
 0.09014702 0.09014702 0.09014702 0.09014702 0.09014702 0.09014702
 0.09014702 0.09014702 0.09014702 0.09014702 0.09014702 0.09014702
 0.09014702 0.09014702 0.09014702 0.09014702 0.09014702 0.09014702]
├── external: Bunch
├── graph: DenseGraph
├── initial_state: Bunch
└── noise: Bunch
@cache("optimize_het", redo=False)
def optimize_het():
    opt = OptaxOptimizer(loss, optax.adam(0.004, b2=0.999), callback=cb)
    fitted_state, fitting_data = opt.run(fitted_state_het, max_steps=200)
    return fitted_state, fitting_data

fitted_state_het, fitting_data_het = optimize_het()
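The effect of turning a scalar parameter into one value per region can be checked with plain broadcasting. The sketch below is independent of tvboptim: `total_input` is a hypothetical helper that evaluates the model's input term, and `J_N` is an assumed constant used only for illustration.

```python
import numpy as np

J_N = 0.2609                               # assumed synaptic coupling constant
n_nodes = 4
weights = np.ones((n_nodes, n_nodes)) - np.eye(n_nodes)   # toy all-to-all graph
S = np.linspace(0.2, 0.5, n_nodes)         # example gating values

def total_input(w, I_o, G, S, weights):
    """x = w * J_N * S + I_o + G * (weights @ S); w and I_o may be scalars or vectors."""
    return w * J_N * S + I_o + G * (weights @ S)

w_global, I_o_global = 0.09, 0.32
w_regional = np.full(n_nodes, w_global)    # one entry per region, same starting value
I_o_regional = np.full(n_nodes, I_o_global)

x_global = total_input(w_global, I_o_global, 0.5, S, weights)
x_regional = total_input(w_regional, I_o_regional, 0.5, S, weights)

w_regional[2] = 0.2                        # change a single region
x_changed = total_input(w_regional, I_o_regional, 0.5, S, weights)
```

Starting the regional vectors at the global optimum means the heterogeneous fit begins from exactly the same dynamics as the global fit, and each entry can then move independently.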

Comparing Global vs Regional Parameters

Let’s compare the FC quality from global (homogeneous) vs regional (heterogeneous) parameter fits.

# Compute FC for both optimization strategies
fc_global = np.array(observation(fitted_state))
fc_regional = np.array(observation(fitted_state_het))
Figure 6: Comparison of functional connectivity matrices. Left: Empirical target FC from resting-state fMRI. Middle: FC from global parameter optimization. Right: FC from regional parameter optimization. The correlation coefficient (r) quantifies the similarity to the target FC. Regional parameters achieve a better fit by accounting for local variations.
Figure 7: Scatter plots comparing fitted vs empirical FC. Each point represents one pairwise connection between brain regions. Left: Global parameter fit shows good overall correlation. Right: Regional parameter fit shows improved correlation with reduced scatter, indicating better reproduction of the empirical FC structure. The diagonal line represents perfect agreement.
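One common way to compute such an r value is to correlate the off-diagonal entries of the two matrices, using only the upper triangle because FC is symmetric. The sketch below assumes that convention; it is not confirmed to be what `fc_corr` does, and `toy_fc_corr` is a hypothetical name. The matrices are synthetic.

```python
import numpy as np

def toy_fc_corr(fc_a, fc_b):
    """Pearson r between the strictly upper-triangular entries of two FC matrices."""
    iu = np.triu_indices_from(fc_a, k=1)   # k=1 excludes the diagonal of ones
    return float(np.corrcoef(fc_a[iu], fc_b[iu])[0, 1])

rng = np.random.default_rng(0)
fc_emp = np.corrcoef(rng.standard_normal((5, 200)))   # synthetic "empirical" FC
fc_sim = fc_emp + 0.05 * rng.standard_normal((5, 5))  # perturbed copy
fc_sim = 0.5 * (fc_sim + fc_sim.T)                    # keep it symmetric

r = toy_fc_corr(fc_emp, fc_sim)
```

Unlike RMSE, this measure ignores overall offset and scale, so a simulated FC can correlate well with the target while still having a sizeable RMSE.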

Fitted Heterogeneous Parameters

Let’s examine the fitted region-specific parameters and their relationship to structural connectivity.

Figure 8: Fitted heterogeneous parameters. Left: Fitted excitatory recurrence (w) for each region plotted against mean incoming structural connectivity strength. Right: Fitted external input (I_o) vs mean connectivity. Dashed lines show the global optimization values for reference. Regions with stronger structural connections tend to require different parameter values to achieve optimal FC fit, demonstrating the importance of region-specific tuning.
Note: Parameter Constraints

Notice that some regions have negative w values, which changes the biological interpretation of the parameter from excitatory to inhibitory recurrence. While this may be mathematically valid for achieving good FC fits, it violates the intended model constraints where w represents excitatory feedback strength.

In a further refinement step, we could use BoundedParameter to enforce physiological constraints during optimization:

# Example of using bounded parameters (not executed here)
from tvboptim.types import BoundedParameter

# Constrain w to positive values only
state.dynamics.w = BoundedParameter(
    state.dynamics.w,
    lower_bound=0.0,    # Enforce excitatory nature
    upper_bound=1.0     # Maximum recurrence strength
)

This would ensure that the optimizer only explores biologically plausible parameter regions while still allowing heterogeneous regional variations.
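One general technique for keeping a parameter inside box bounds during unconstrained gradient updates is to optimize a transformed variable and map it back through a logistic function. The sketch below shows that technique in NumPy; whether `BoundedParameter` works this way internally is not stated in this tutorial, and `to_bounded` and `to_unbounded` are hypothetical names.

```python
import numpy as np

def to_bounded(u, low, high):
    """Map an unconstrained value u into the open interval (low, high)."""
    return low + (high - low) / (1.0 + np.exp(-u))

def to_unbounded(p, low, high):
    """Inverse map (logit); p must lie strictly between low and high."""
    z = (p - low) / (high - low)
    return np.log(z) - np.log1p(-z)

w = np.array([0.09, 0.5])                  # example recurrence values
u = to_unbounded(w, 0.0, 1.0)              # the optimizer would update u, not w

# Even a very large step in u maps back to a value inside (0, 1)
w_after_big_step = to_bounded(u - 50.0, 0.0, 1.0)
```

An alternative is to clip values after each update, which is simpler but gives zero gradient for parameters that sit on a bound.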

Summary

This tutorial demonstrated the complete workflow for fitting brain network models to functional connectivity data using TVB-Optim:

  1. Model Construction: We built a whole-brain network using the Reduced Wong-Wang neural mass model with structural connectivity.

  2. BOLD Simulation: We converted neural activity to simulated BOLD signal using a hemodynamic response function, matching the temporal scale of fMRI data.

  3. FC Computation: We computed functional connectivity from simulated BOLD and defined a loss function measuring mismatch with empirical FC.

  4. Parameter Exploration: We systematically explored the parameter space to understand the relationship between model parameters and FC quality.

  5. Gradient-Based Optimization: We used automatic differentiation through the entire simulation pipeline to optimize global parameters.

  6. Heterogeneous Parameters: We refined the model with region-specific parameters, achieving better FC fits by accounting for regional variations in neural dynamics.

This approach showcases TVB-Optim’s capability to perform end-to-end optimization of complex, biophysically-realistic brain network models with automatic differentiation through stochastic simulations and signal processing pipelines.