Jansen-Rit Peak Frequency Optimization

Reproducing MEG Resting-State Frequency Gradients with Network Dynamics

Introduction

This tutorial demonstrates how to use TVB-Optim’s Network Dynamics framework to model and optimize brain network dynamics. We aim to reproduce the frequency gradient observed in resting-state MEG data, where different brain regions exhibit characteristic peak frequencies that vary with their distance from visual cortex.

Environment Setup and Imports
# Set up environment
import os
import time
cpu = True
if cpu:
    N = 8
    os.environ['XLA_FLAGS'] = f'--xla_force_host_platform_device_count={N}'

# Import all required libraries
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patheffects as path_effects
import jax
import jax.numpy as jnp
import copy
import equinox as eqx
import optax
from IPython.display import Markdown, display
import pandas as pd
import nibabel as nib
from nilearn import plotting

# Import from tvboptim
from tvboptim import prepare
from tvboptim.types import Parameter, Space, GridAxis
from tvboptim.types.stateutils import show_parameters
from tvboptim.utils import set_cache_path, cache
from tvboptim.execution import ParallelExecution, SequentialExecution
from tvboptim.optim.optax import OptaxOptimizer
from tvboptim.optim.callbacks import MultiCallback, DefaultPrintCallback, SavingCallback, AbstractCallback

# network_dynamics imports
from tvboptim.experimental.network_dynamics import Network, solve, prepare
from tvboptim.experimental.network_dynamics.dynamics.tvb import JansenRit
from tvboptim.experimental.network_dynamics.coupling import LinearCoupling, SigmoidalJansenRit, DelayedSigmoidalJansenRit, AbstractCoupling
from tvboptim.experimental.network_dynamics.graph import DenseGraph, DenseDelayGraph
from tvboptim.experimental.network_dynamics.solvers import Euler, Heun
from tvboptim.experimental.network_dynamics.noise import AdditiveNoise
from tvboptim.data import load_structural_connectivity
from tvboptim.experimental.network_dynamics.utils import print_network

# Set cache path for tvboptim
set_cache_path("./jr")

Loading Structural Data

We load the Desikan-Killiany parcellation atlas and structural connectivity data.

Load structural connectivity data
# Load volume data for visualization
# Try local path first, fall back to GitHub download for Colab
try:
    dk_info = pd.read_csv("../data/dk_average/fs_default_freesurfer_idx.csv")
    dk = nib.load("../data/dk_average/aparc+aseg-mni_09c.nii.gz")
except FileNotFoundError:
    import urllib.request
    from pathlib import Path

    # Download from GitHub to local cache
    cache_dir = Path.home() / ".cache" / "tvboptim" / "dk_average"
    cache_dir.mkdir(parents=True, exist_ok=True)
    base_url = "https://raw.githubusercontent.com/virtual-twin/tvboptim/main/docs/data/dk_average"

    for filename in ["fs_default_freesurfer_idx.csv", "aparc+aseg-mni_09c.nii.gz"]:
        local_file = cache_dir / filename
        if not local_file.exists():
            print(f"Downloading {filename}...")
            urllib.request.urlretrieve(f"{base_url}/{filename}", local_file)

    dk_info = pd.read_csv(cache_dir / "fs_default_freesurfer_idx.csv")
    dk = nib.load(cache_dir / "aparc+aseg-mni_09c.nii.gz")

dk_info.drop_duplicates(subset="freesurfer_idx", inplace=True)

# Load structural connectivity with region labels
weights, lengths, region_labels = load_structural_connectivity(name="dk_average")

Background: Resting State MEG

Experimental studies of resting-state brain activity using MEG have revealed a systematic gradient in oscillatory peak frequencies across cortical regions. Lower-frequency alpha oscillations (~7-8 Hz) dominate in higher-order association areas, while higher frequencies (~10-11 Hz) are observed in sensory cortices. This gradient follows the cortical hierarchy from sensory to association areas.

Figure 1: Resting State MEG Peak Frequency Gradient. Taken from: Mahjoory, K., Schoffelen, J.-M., Keitel, A., & Gross, J. (2020). The frequency gradient of human resting-state brain oscillations follows cortical hierarchies. eLife, 9, e53715. https://doi.org/10.7554/eLife.53715

Defining Target Frequencies

We approximate this gradient by mapping peak frequencies based on the distance from visual cortex (lateral occipital gyrus), ranging from 11 Hz at visual cortex to 7 Hz at the most distant regions.

Compute distance-based frequency targets
ab_map = np.zeros(dk.shape)
# Mark the lateral occipital gyrus (left and right) in the volume by looking up
# its FreeSurfer index via the acronym column of dk_info
idx = dk_info[dk_info.acronym == "L.LOG"].freesurfer_idx.values[0]
ab_map = np.where(dk.get_fdata() == idx, 1, ab_map)
idx = dk_info[dk_info.acronym == "R.LOG"].freesurfer_idx.values[0]
ab_map = np.where(dk.get_fdata() == idx, 1, ab_map)

region_labels = np.array(region_labels)
idx_l = np.where(region_labels == "L.LOG")[0]
idx_r = np.where(region_labels == "R.LOG")[0]
dist_from_vc = np.array(np.squeeze(0.5 * (lengths[idx_l, :] + lengths[idx_r, :])))

dist_map = np.zeros(dk.shape)

for name, dist in zip(region_labels, dist_from_vc):
    idx = dk_info[dk_info.acronym == name].freesurfer_idx.values[0]
    dist_map = np.where(dk.get_fdata() == idx, dist, dist_map)

n_nodes = 84
f_min = 7 # at max distance
f_max = 11 # at VC
min_dist = dist_from_vc.min()
max_dist = dist_from_vc.max()
delta_f = (f_max - f_min) / (max_dist - min_dist)
peak_freqs = f_max - delta_f * (dist_from_vc - min_dist)

f_map_target = np.zeros_like(dk.get_fdata())
for name, fp in zip(region_labels, peak_freqs):
    idx = dk_info[dk_info.acronym == name].freesurfer_idx.values[0]
    f_map_target = np.where(dk.get_fdata() == idx, fp, f_map_target)
Figure 2: Target frequency gradient based on anatomical distance. Left: White matter fiber tract lengths (mm) from each region to the lateral occipital gyrus (black outline), which approximates visual cortex in the Desikan-Killiany parcellation. Right: Target peak frequencies mapped from distance, with 7 Hz for the most distant regions and 11 Hz for visual cortex.

The Jansen-Rit Model

The Jansen-Rit model is a neural mass model that describes the mean activity of interconnected populations of excitatory and inhibitory neurons in a cortical column. It produces realistic alpha-band oscillations and is widely used in whole-brain modeling.

The model has three key rate parameters: a and b are the inverse membrane time constants of the excitatory and inhibitory populations, respectively, while mu sets the mean baseline input. Together, these parameters determine the oscillation frequency and amplitude.
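
For reference, the second-order form of the model as implemented in the source code shown further below reads, with Sigm(v) = 2ν_max / (1 + exp(r(v₀ − v))) and total coupling input C:

$$
\begin{aligned}
\dot y_0 &= y_3, &\quad \dot y_3 &= A\,a\,\mathrm{Sigm}(y_1 - y_2) - 2a\,y_3 - a^2 y_0,\\
\dot y_1 &= y_4, &\quad \dot y_4 &= A\,a\,\bigl(\mu + a_2 J\,\mathrm{Sigm}(a_1 J\,y_0) + C\bigr) - 2a\,y_4 - a^2 y_1,\\
\dot y_2 &= y_5, &\quad \dot y_5 &= B\,b\,a_4 J\,\mathrm{Sigm}(a_3 J\,y_0) - 2b\,y_5 - b^2 y_2.
\end{aligned}
$$

Here a and b set both the gain and the damping of the excitatory and inhibitory post-synaptic responses, which is why varying them shifts the peak frequency.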

dynamics = JansenRit(a = 0.075, b = 0.075, mu = 0.15)
dynamics.plot(t1=1000, dt = 1)
Figure 3: Single Jansen-Rit oscillator dynamics. Time series showing characteristic alpha-band oscillations with parameters a=0.075 ms⁻¹, b=0.075 ms⁻¹, and μ=0.15.

Code for the Jansen-Rit Model and Sigmoidal Coupling

View source code for JansenRit and DelayedSigmoidalJansenRit
class JansenRit(AbstractDynamics):
    """Jansen-Rit model with multi-coupling support."""

    STATE_NAMES = ("y0", "y1", "y2", "y3", "y4", "y5")
    INITIAL_STATE = (0.0, 5.0, 5.0, 0.0, 0.0, 0.0)

    AUXILIARY_NAMES = ("sigm_y1_y2", "sigm_y0_1", "sigm_y0_3")

    DEFAULT_PARAMS = Bunch(
        A=3.25,          # Maximum amplitude of EPSP [mV]
        B=22.0,          # Maximum amplitude of IPSP [mV]
        a=0.1,           # Reciprocal of membrane time constant [ms^-1]
        b=0.05,          # Reciprocal of membrane time constant [ms^-1]
        v0=5.52,         # Firing threshold [mV]
        nu_max=0.0025,   # Maximum firing rate [ms^-1]
        r=0.56,          # Steepness of sigmoid [mV^-1]
        J=135.0,         # Average number of synapses
        a_1=1.0,         # Excitatory feedback probability
        a_2=0.8,         # Slow excitatory feedback probability
        a_3=0.25,        # Inhibitory feedback probability
        a_4=0.25,        # Slow inhibitory feedback probability
        mu=0.22,         # Mean input firing rate
    )

    # Multi-coupling: instantaneous and delayed
    COUPLING_INPUTS = {
        'instant': 1,
        'delayed': 1,
    }

    def dynamics(
        self,
        t: float,
        state: jnp.ndarray,
        params: Bunch,
        coupling: Bunch
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        # Unpack parameters
        A, B = params.A, params.B
        a, b = params.a, params.b
        v0, nu_max, r = params.v0, params.nu_max, params.r
        J = params.J
        a_1, a_2, a_3, a_4 = params.a_1, params.a_2, params.a_3, params.a_4
        mu = params.mu

        # Unpack state variables
        y0, y1, y2, y3, y4, y5 = state[0], state[1], state[2], state[3], state[4], state[5]

        # Unpack coupling inputs
        c_instant = coupling.instant[0]
        c_delayed = coupling.delayed[0]

        # Sigmoid functions
        sigm_y1_y2 = 2.0 * nu_max / (1.0 + jnp.exp(r * (v0 - (y1 - y2))))
        sigm_y0_1 = 2.0 * nu_max / (1.0 + jnp.exp(r * (v0 - (a_1 * J * y0))))
        sigm_y0_3 = 2.0 * nu_max / (1.0 + jnp.exp(r * (v0 - (a_3 * J * y0))))

        # State derivatives (both couplings add to excitatory interneuron)
        dy0_dt = y3
        dy1_dt = y4
        dy2_dt = y5
        dy3_dt = A * a * sigm_y1_y2 - 2.0 * a * y3 - a**2 * y0
        dy4_dt = A * a * (mu + a_2 * J * sigm_y0_1 + c_instant + c_delayed) - 2.0 * a * y4 - a**2 * y1
        dy5_dt = B * b * (a_4 * J * sigm_y0_3) - 2.0 * b * y5 - b**2 * y2

        # Package results
        derivatives = jnp.array([dy0_dt, dy1_dt, dy2_dt, dy3_dt, dy4_dt, dy5_dt])
        auxiliaries = jnp.array([sigm_y1_y2, sigm_y0_1, sigm_y0_3])

        return derivatives, auxiliaries

class DelayedSigmoidalJansenRit(DelayedCoupling):
    """Sigmoidal Jansen-Rit coupling function."""

    N_OUTPUT_STATES = 1
    DEFAULT_PARAMS = Bunch(
        G=1.0,
        cmin=0.0,
        cmax=0.005,  # 2 * 0.0025 from TVB default
        midpoint=6.0,
        r=0.56,
    )

    def pre(
        self, incoming_states: jnp.ndarray, local_states: jnp.ndarray, params: Bunch
    ) -> jnp.ndarray:
        # Extract first two state variables: y1 and y2
        # incoming_states[0] = y1, incoming_states[1] = y2
        # Each has shape [n_nodes_target, n_nodes_source]
        state_diff = incoming_states[0] - incoming_states[1]

        # Apply sigmoidal transformation
        exp_term = jnp.exp(params.r * (params.midpoint - state_diff))
        coupling_term = params.cmin + (params.cmax - params.cmin) / (1.0 + exp_term)

        # Return as [1, n_nodes_target, n_nodes_source] for matrix multiplication
        return coupling_term[jnp.newaxis, :, :]

    def post(
        self, summed_inputs: jnp.ndarray, local_states: jnp.ndarray, params: Bunch
    ) -> jnp.ndarray:
        # Scale summed coupling inputs
        return params.G * summed_inputs

Building the Network Model

To create a whole-brain network model, we combine four key components:

  1. Graph: Defines the structural connectivity (weights) and transmission delays between regions
  2. Dynamics: The local neural mass model at each node (Jansen-Rit)
  3. Coupling: How regions influence each other through connections
  4. Noise: Stochastic fluctuations in neural activity
# Load and normalize structural connectivity
weights, lengths, _ = load_structural_connectivity(name="dk_average")
weights = weights / jnp.max(weights)  # Normalize to [0, 1]
n_nodes = weights.shape[0]

# Convert tract lengths to transmission delays (speed = 3 mm/ms)
speed = 3.0
delays = lengths / speed

# Create network components
graph = DenseDelayGraph(weights, delays, region_labels=region_labels)
dynamics = JansenRit(a = 0.065, b = 0.065, mu = 0.15)
coupling = DelayedSigmoidalJansenRit(incoming_states=["y1", "y2"], G = 15.0)
noise = AdditiveNoise(sigma = 1e-04)

# Assemble the network
network = Network(
    dynamics=dynamics,
    coupling={'delayed': coupling},
    graph=graph,
    noise=noise
)
print_network(network)
 Network Dynamics Network System
==================================================

Dynamics: JansenRit
  States: y0, y1, y2, y3, y4, y5
  Initial: y0=0, y1=5, y2=5, y3=0, y4=0, y5=0

Graph: DenseDelayGraph
  Nodes: 84
  Max delay: 74.73873138427734 ms

Noise: AdditiveNoise
  Apply to: all states
  Params: sigma=0.0001

Couplings
--------------------------------------------------
1. delayed (DelayedSigmoidalJansenRit)
   Type: delayed
   States: incoming=(y1, y2)
   Form: post(Σⱼ wᵢⱼ * (y1ⱼ, y2ⱼ)(t - τᵢⱼ))
   post: 15.0 * (...)
   params: G=15.0, cmin=0.0, cmax=0.005, midpoint=6.0, r=0.56
   Max delay: 74.73873138427734 ms


Dynamics Equations
--------------------------------------------------
    def dynamics(self, t: float, state: jnp.ndarray, params: Bunch, coupling: Bunch, external: Bunch, ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        # Unpack parameters
        A, B = params.A, params.B
        a, b = params.a, params.b
        v0, nu_max, r = params.v0, params.nu_max, params.r
        J = params.J
        a_1, a_2, a_3, a_4 = params.a_1, params.a_2, params.a_3, params.a_4
        mu = params.mu

        # Unpack state variables
        y0, y1, y2, y3, y4, y5 = (
            state[0],
            state[1],
            state[2],
            state[3],
            state[4],
            state[5],
        )

        # Unpack coupling inputs
        c_instant = coupling.instant[0]
        # ↳ instant: 0 (not provided)
        c_delayed = coupling.delayed[0]
        # ↳ delayed: post(Σⱼ wᵢⱼ * (y1ⱼ, y2ⱼ)(t - τᵢⱼ))

        # Sigmoid functions
        sigm_y1_y2 = 2.0 * nu_max / (1.0 + jnp.exp(r * (v0 - (y1 - y2))))
        sigm_y0_1 = 2.0 * nu_max / (1.0 + jnp.exp(r * (v0 - (a_1 * J * y0))))
        sigm_y0_3 = 2.0 * nu_max / (1.0 + jnp.exp(r * (v0 - (a_3 * J * y0))))

        # State derivatives (both couplings add to excitatory interneuron)
        dy0_dt = y3
        dy1_dt = y4
        dy2_dt = y5
        dy3_dt = A * a * sigm_y1_y2 - 2.0 * a * y3 - a**2 * y0
        dy4_dt = (
            A * a * (mu + a_2 * J * sigm_y0_1 + c_instant + c_delayed)
            - 2.0 * a * y4
            - a**2 * y1
        )
        dy5_dt = B * b * (a_4 * J * sigm_y0_3) - 2.0 * b * y5 - b**2 * y2

        # Package results
        derivatives = jnp.array([dy0_dt, dy1_dt, dy2_dt, dy3_dt, dy4_dt, dy5_dt])
        auxiliaries = jnp.array([sigm_y1_y2, sigm_y0_1, sigm_y0_3])

        return derivatives, auxiliaries


Parameters
--------------------------------------------------
  A=3.25, B=22, a=0.065, b=0.065, v0=5.52, nu_max=0.0025, r=0.56, J=135, a_1=1, a_2=0.8, a_3=0.25, a_4=0.25, mu=0.15

Preparing the Simulation

The prepare() function compiles the network into an efficient JAX function and initializes the state. We use the Heun solver (a second-order Runge-Kutta method) for numerical integration. Alternatively, solve() runs the simulation directly and is useful for quick prototyping.
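
For the deterministic part of the system, one Heun step first takes an Euler predictor and then averages the slopes at both ends of the interval:

$$
\tilde y = y_n + \Delta t\, f(t_n, y_n), \qquad
y_{n+1} = y_n + \tfrac{\Delta t}{2}\bigl(f(t_n, y_n) + f(t_{n+1}, \tilde y)\bigr).
$$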

# Prepare simulation: compile model and initialize state
t1 = 1000  # Duration of the main simulation (ms)
t1_init = 20000  # Duration of the initial transient simulation (ms)
dt = 1.0   # Integration timestep (ms)
model, state = prepare(network, Heun(), t1 = t1_init, dt = dt)
# result = solve(network, Heun(), t1 = t1_init, dt = dt)

Running Simulations

We run two consecutive simulations: the first lets the network settle from its initial conditions into a quasi-stationary state (the transient), and the second continues from that settled state.

# First simulation: transient dynamics
result = model(state)

# Update network with final state as new initial conditions
network.update_history(result)
model, state = prepare(network, Heun(), t1 = t1, dt = dt)

# Second simulation: quasi-stationary dynamics
result2 = model(state)
Figure 4: Network dynamics from transient and quasi-stationary simulations for y0. Top: First 1000 ms showing transient behavior from the initial conditions. Bottom: Dynamics after the network has settled into a quasi-stationary state. Each line represents one brain region, colored by mean activity level.

Spectral Analysis

To compare our simulations with the target MEG frequency gradient, we need to compute power spectral densities (PSDs) for each brain region. We define helper functions using Welch’s method for spectral estimation. Building observations progressively by wrapping them in functions is a common pattern in JAX.

def spectrum(state):
    """Compute power spectral density for each region using Welch's method."""
    raw = model(state)
    # Subsample by a factor of 10 to get 100 Hz sampling rate
    f, Pxx = jax.scipy.signal.welch(raw.data[::10, 0, :].T, fs=np.round(1000/dt / 10))
    return f, Pxx

def avg_spectrum(state):
    """Compute average power spectrum across all regions."""
    f, Pxx = spectrum(state)
    avg_spectrum = jnp.mean(Pxx, axis=0)
    return f, avg_spectrum

def peak_freq(state):
    """Extract mean peak frequency from average spectrum."""
    f, S = avg_spectrum(state)
    idx = jnp.argmax(S)
    f_max = f[idx]
    return f_max

# Calculate spectra for visualization
f, S = jax.block_until_ready(avg_spectrum(state))
f, Pxx = jax.block_until_ready(spectrum(state))
Figure 5: Power spectral density of initial network dynamics. Gray lines show individual region spectra, while the black line shows the network-average spectrum. With homogeneous parameters (a=0.065, b=0.065 for all regions), the network exhibits a single dominant peak frequency.

Parameter Exploration

Before attempting optimization, it’s useful to understand the relationship between the Jansen-Rit parameters (a and b) and the resulting peak frequency. We perform a grid search over parameter space to map out this relationship.

The Space and GridAxis classes from TVB-Optim make it easy to define parameter grids, and ParallelExecution efficiently evaluates the model across all grid points in parallel.

# Create grid for parameter exploration
n = 32

# Set up parameter axes for exploration
grid_state = copy.deepcopy(state)
grid_state.dynamics.a = GridAxis(low = 0.001, high = 0.2, n = n)
grid_state.dynamics.b = GridAxis(low = 0.001, high = 0.2, n = n)

# Create space (product creates all combinations of a and b)
grid = Space(grid_state, mode = "product")

@cache("explore", redo = False)
def explore():
    # Parallel execution across 8 processes
    exec = ParallelExecution(peak_freq, grid, n_pmap=8)
    # Alternative: Sequential execution
    # exec = SequentialExecution(jax.jit(peak_freq), grid)
    return exec.run()

exploration_result = explore()
Figure 6: Mean peak frequency landscape across parameter space. The heatmap shows how the network’s dominant frequency varies with the excitatory (a) and inhibitory (b) inverse time constants. This landscape guides our optimization by showing which parameter combinations can produce the desired 7-11 Hz frequency range.

Creating Target Spectra

To define optimization targets, we model each region’s desired spectrum as a Cauchy (Lorentzian) distribution centered at that region’s target peak frequency. This gives a realistic spectral shape with a clear peak.
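
For a region with target peak f₀, the target spectrum is simply the Cauchy density with half-width γ = 1 Hz, evaluated on the same frequency grid as the Welch spectra:

$$
p(f) = \frac{1}{\pi\gamma\left[1 + \left(\frac{f - f_0}{\gamma}\right)^{2}\right]}.
$$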

Show code for target spectra generation and plotting
def cauchy_pdf(x, x0, gamma):
    """Cauchy distribution for modeling spectral peaks."""
    return 1 / (np.pi * gamma * (1 + ((x - x0) / gamma) ** 2))

# Create target PSD for each region
target_PSDs = [cauchy_pdf(f, fp, 1) for fp in peak_freqs]
target_PSDs = jnp.vstack(target_PSDs)

# Visualize initial vs target spectra for a subset of regions
fig, ax = plt.subplots(figsize=(8.1, 3.0375))
n_show = 7
for (y, y2) in zip(Pxx[:n_show], target_PSDs[:n_show]):
    ax.plot(f, y, alpha = 0.3, color = "k", linewidth=1.5)
    ax.plot(f, y2, alpha = 0.3, color = "royalblue", linewidth=1.5)
ax.set_xlabel('Frequency [Hz]')
ax.set_ylabel('Power')
ax.set_title(f'{n_show} Initial vs Target Power Spectra')
ax.set_xlim(0, 20)
ax.set_yscale('log')
ax.legend(['Initial', 'Target'], loc='upper right')
plt.tight_layout()
Figure 7: Comparison of initial and target spectra. Black lines show the current power spectra from 7 example regions (with homogeneous parameters). Blue lines show the corresponding target spectra with region-specific peak frequencies based on the MEG gradient. The goal of optimization is to match the black lines to the blue lines for all 84 regions.

Defining the Loss Function

The loss function quantifies how well the model’s spectra match the target spectra. We use correlation between predicted and target power spectra as our similarity metric. The loss is defined as 1 minus the mean correlation across all regions, so minimizing the loss maximizes spectral similarity.

def loss(state):
    """
    Compute loss as 1 - mean correlation between predicted and target PSDs.

    Lower loss means better match between model and target spectra.
    """
    # Compute the power spectral density from the current state
    frequencies, predicted_psd = spectrum(state)

    # Calculate correlation coefficient between each predicted PSD and its target
    # vmap applies the correlation function across all PSD pairs in parallel
    correlations = jax.vmap(
        lambda predicted, target: jnp.corrcoef(predicted, target)[0, 1]
    )(predicted_psd, target_PSDs)

    # Return 1 minus the mean correlation as the loss
    # (we want to minimize this, which maximizes correlation)
    average_correlation = correlations.mean()
    aux_data = predicted_psd # Used for plotting during optimization
    return (1 - average_correlation), aux_data

# Test loss function
loss(state)
(Array(0.56804025, dtype=float32),
 Array([[1.23043878e-06, 8.56911015e-07, 1.09341336e-07, ...,
         5.36267752e-10, 1.18968435e-09, 4.82776319e-10],
        [1.92068295e-09, 8.76787354e-08, 4.32468795e-07, ...,
         6.74991441e-09, 2.40194020e-09, 4.37712255e-10],
        [8.12171322e-07, 3.76594755e-07, 4.69840309e-08, ...,
         1.66150749e-09, 5.22748289e-09, 1.13152709e-10],
        ...,
        [9.44446981e-07, 3.93654688e-07, 2.93623756e-08, ...,
         7.72786946e-10, 1.87371363e-09, 1.99487360e-09],
        [4.91253616e-07, 6.68663276e-08, 2.23511307e-07, ...,
         1.85898976e-08, 2.67928311e-08, 1.62876699e-08],
        [5.00447825e-08, 1.79569227e-07, 1.17351640e-07, ...,
         1.68474514e-08, 9.72651559e-09, 1.17871410e-08]], dtype=float32))

Running Gradient-Based Optimization

Now we use gradient-based optimization to find region-specific parameters a and b that produce the target frequency gradient. We mark parameters as optimizable using the Parameter class and set their shape to be region-specific (heterogeneous).

TVB-Optim integrates with optax for optimization, providing automatic differentiation through JAX. We use the AdaMax optimizer with callbacks for progress monitoring.

# Mark parameters as optimizable and make them heterogeneous (region-specific)
init_state = copy.deepcopy(state)
init_state.dynamics.a = Parameter(init_state.dynamics.a)
init_state.dynamics.a.shape = (n_nodes,)  # One value per region
init_state.dynamics.b = Parameter(init_state.dynamics.b)
init_state.dynamics.b.shape = (n_nodes,)  # One value per region

# Set up callbacks for monitoring
cb = MultiCallback([
    DefaultPrintCallback(every=10),  # Print progress
    PlotCallback(every=10),          # Plot progress; assumed to be defined in a notebook cell not shown here
])

# Create optimizer
opt = OptaxOptimizer(loss, optax.adamaxw(0.001), callback = cb, has_aux=True)

@cache("optimize", redo = False)
def optimize():
    fitted_params, fitting_data = opt.run(init_state, max_steps=151)
    return fitted_params, fitting_data

fitted_params, fitting_data = optimize()
Loading optimize from cache, last modified 2025-12-17 15:27:37.299641

Optimization Results: Parameter Distribution

Let’s visualize where the optimizer placed each region’s parameters in the frequency landscape we explored earlier.
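
A minimal plotting sketch for this overlay is shown below. It assumes that exploration_result can be reshaped onto the n × n grid defined above (with a varying along the first axis) and that the fitted values are accessible as fitted_params.dynamics.a and fitted_params.dynamics.b; adapt these names to the actual result structure.

# Sketch: overlay fitted (a, b) values on the exploration heatmap.
# Assumptions: exploration_result reshapes to (n, n) with `a` along the first axis,
# GridAxis spans its range linearly, and fitted_params exposes .dynamics.a / .dynamics.b.
a_axis = np.linspace(0.001, 0.2, n)
b_axis = np.linspace(0.001, 0.2, n)
freq_grid = np.asarray(exploration_result).reshape(n, n)

fig, ax = plt.subplots(figsize=(5, 4))
im = ax.pcolormesh(a_axis, b_axis, freq_grid.T, shading="auto", cmap="viridis")
ax.scatter(np.asarray(fitted_params.dynamics.a),
           np.asarray(fitted_params.dynamics.b),
           s=12, color="white", edgecolors="black", linewidths=0.3)
ax.set_xlabel("a [1/ms]")
ax.set_ylabel("b [1/ms]")
fig.colorbar(im, ax=ax, label="Peak frequency [Hz]")
plt.tight_layout()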

Figure 8: Optimized parameters overlaid on the frequency landscape. Each white dot represents the fitted (a, b) parameters for one brain region. The optimizer has distributed regions across parameter space to achieve their target frequencies in the 7-11 Hz range.

Optimization Results: Frequency Maps

Finally, let’s visualize the fitted frequency gradient in brain space and compare it to the initial homogeneous state.
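
One way to build the fitted frequency map, mirroring the construction of f_map_target above, is sketched here. It assumes that the optimized state returned by the optimizer (fitted_params) can be passed through the spectrum() helper and that per-region peaks are read off with argmax.

# Sketch: per-region peak frequencies of the fitted model, painted into the DK volume.
# Assumption: fitted_params is a valid simulation state for the spectrum() helper.
f_fit, Pxx_fit = spectrum(fitted_params)
fitted_peaks = np.asarray(f_fit)[np.argmax(np.asarray(Pxx_fit), axis=1)]

f_map_fitted = np.zeros_like(dk.get_fdata())
for name, fp in zip(region_labels, fitted_peaks):
    idx = dk_info[dk_info.acronym == name].freesurfer_idx.values[0]
    f_map_fitted = np.where(dk.get_fdata() == idx, fp, f_map_fitted)

# Display on the atlas volume with nilearn
img = nib.Nifti1Image(f_map_fitted, dk.affine)
plotting.plot_stat_map(img, display_mode="z", cut_coords=5,
                       cmap="viridis", colorbar=True,
                       title="Fitted peak frequency [Hz]")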

Figure 9: Comparison of initial and fitted frequency gradients. Left: Initial homogeneous network showing similar frequencies across all regions. Right: Fitted frequency gradient matching the target MEG pattern, with lower frequencies in association areas and higher frequencies in sensory cortex.

Summary

This tutorial demonstrated the complete workflow for fitting brain network models using TVB-Optim:

  1. Model Construction: We built a whole-brain network using the Jansen-Rit neural mass model with structural connectivity and transmission delays.

  2. Target Definition: We defined region-specific target frequencies based on the empirical MEG gradient, creating synthetic target spectra using Cauchy distributions.

  3. Parameter Exploration: We mapped out the relationship between model parameters and oscillation frequencies using parallel grid search.

  4. Gradient-Based Optimization: We used automatic differentiation and the AdaMax optimizer to fit heterogeneous (region-specific) parameters that reproduce the target frequency gradient.

  5. Validation: We visualized the results showing successful reproduction of the spatial frequency gradient across brain regions.

This approach showcases key features of TVB-Optim including the modular Network architecture, JAX-based automatic differentiation for gradient-based optimization, and integration with visualization tools for neuroimaging data.