Define a Brain Network Model in TVB-Optim, TVB-O, or TVB

This page provides installation instructions and guides you through the TVB-Optim workflow. We follow a parallel structure depending on how you want to define your brain network model.

Installation & Requirements

TVB-Optim requires Python 3.11 or later and depends on JAX for high-performance computing and automatic differentiation.

Install TVB-Optim

TVB-Optim is a standalone package that provides utilities for optimization algorithms, parameter spaces, and execution strategies. You can define models directly in TVB-Optim or use models from other frameworks:

With uv:

uv pip install tvboptim

For development:

git clone https://github.com/virtual-twin/tvboptim.git
cd tvboptim
uv pip install -e ".[dev]"

With pip:

pip install tvboptim

For development:

git clone https://github.com/virtual-twin/tvboptim.git
cd tvboptim
pip install -e ".[dev]"

Optional: Install TVB-O

TVB-O is only required if you want to:

  - Use models defined in classic TVB (The Virtual Brain)
  - Access models from the TVB ontology
  - Use predefined brain connectivity data and atlases

If you need TVB-O, install it with:

With uv:

uv pip install tvbo

For development:

git clone https://github.com/virtual-twin/tvbo.git
cd tvbo
uv pip install -e ".[dev]"

With pip:

pip install tvbo

For development:

git clone https://github.com/virtual-twin/tvbo.git
cd tvbo
pip install -e ".[dev]"

Imports
# Set up environment
import os
import copy

# Mock devices to force JAX to parallelize on CPU (pmap trick)
# This allows parallel execution even without multiple GPUs
cpu = True
if cpu:
    N = 8  # Number of virtual devices to create
    os.environ['XLA_FLAGS'] = f'--xla_force_host_platform_device_count={N}'

# Import all required libraries
import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import optax  # JAX-based optimization library
from IPython.display import Markdown

# Import from tvboptim - our optimization and execution framework
from tvboptim import prepare  # Converts TVB-O experiments to JAX functions
from tvboptim.types import Parameter  # Parameter types
from tvboptim.types.spaces import Space, GridAxis  # Parameter spaces
from tvboptim.types.stateutils import show_parameters  # Utility functions
from tvboptim.utils import set_cache_path, cache  # Caching for expensive computations
from tvboptim.execution import ParallelExecution, SequentialExecution  # Execution strategies
from tvboptim.optim.optax import OptaxOptimizer  # JAX-based optimizer with automatic differentiation
from tvboptim.optim.callbacks import MultiCallback, DefaultPrintCallback, StopLossCallback  # Optimization callbacks

# Import network_dynamics from tvboptim.experimental - for standalone TVB-Optim models
from tvboptim.experimental import network_dynamics as nd
from tvboptim.experimental.network_dynamics.dynamics import ReducedWongWang
from tvboptim.experimental.network_dynamics.coupling import LinearCoupling
from tvboptim.experimental.network_dynamics.graph import DenseDelayGraph
from tvboptim.experimental.network_dynamics.solvers import Heun
from tvboptim.experimental.network_dynamics.noise import AdditiveNoise
from tvboptim.data import load_structural_connectivity

# Import from tvbo - the brain simulation framework
from tvbo.export.experiment import SimulationExperiment  # Main experiment class
from tvbo.datamodel import tvbo_datamodel  # Data structures
from tvbo.utils import numbered_print  # Utility functions

# Classic TVB imports - for using TVB simulator directly
import warnings
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    from tvb.simulator.lab import simulator, models, connectivity, coupling, integrators, monitors

# Set cache path for tvboptim - stores expensive computations for reuse
set_cache_path("./get_started")
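set_cache_path and the @cache decorator used later store expensive results on disk so that a rerun can reload them. The sketch below shows one way a name-keyed disk cache can work. disk_cache is hypothetical and is not tvboptim's implementation. It assumes the result can be pickled and that the name alone identifies the computation, so function arguments are ignored:

```python
import functools
import os
import pickle
import tempfile


def disk_cache(name, redo=False, cache_dir="./cache_sketch"):
    """Hypothetical name-keyed disk cache (a sketch, not tvboptim's cache).

    The key is only `name`; arguments are ignored, which suits
    zero-argument functions such as explore().
    """
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            path = os.path.join(cache_dir, f"{name}.pkl")
            if not redo and os.path.exists(path):
                with open(path, "rb") as f:      # reuse the stored result
                    return pickle.load(f)
            result = fn(*args, **kwargs)          # compute once
            os.makedirs(cache_dir, exist_ok=True)
            with open(path, "wb") as f:          # store it for the next call or session
                pickle.dump(result, f)
            return result
        return wrapper
    return decorator


# Usage: the function body only runs on the first call
calls = []
tmp_dir = tempfile.mkdtemp()


@disk_cache("expensive", cache_dir=tmp_dir)
def expensive():
    calls.append(1)
    return sum(i * i for i in range(1000))


first = expensive()   # computed and written to disk
second = expensive()  # loaded from disk
```

Passing redo=True would force recomputation, analogous to the redo flag used with @cache later on this page.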
Load Data
# Load structural connectivity data from Desikan-Killiany atlas
# Returns: weights (connectivity matrix), lengths (tract lengths in mm), labels (region names)
weights, lengths, labels = load_structural_connectivity("dk_average")

# Normalize weights to [0, 1] range
weights = weights / np.max(weights)

# Calculate delays (ms) from tract lengths (mm) and conduction speed (mm/ms)
# speed = np.inf gives zero delays; use e.g. speed = 3.0 for 3 mm/ms
speed = np.inf
delays = lengths / speed
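The helper below shows what the speed setting does. It converts a tract-length matrix into delays and into whole integration steps. delays_from_lengths and the 3-region matrix are hypothetical. Rounding to the nearest step is an assumption: it is one common convention in delay-coupled integrators, not necessarily what tvboptim does internally:

```python
import numpy as np


def delays_from_lengths(lengths, speed, dt):
    """Convert tract lengths (mm) to delays (ms) and to whole integration steps."""
    delays = lengths / speed                      # ms; speed = np.inf gives all zeros
    steps = np.rint(delays / dt).astype(int)      # nearest whole number of dt-sized steps
    return delays, steps


# Hypothetical 3-region tract-length matrix in mm
toy_lengths = np.array([[0.0, 36.0, 120.0],
                        [36.0, 0.0, 60.0],
                        [120.0, 60.0, 0.0]])

delays_inf, steps_inf = delays_from_lengths(toy_lengths, np.inf, dt=4.0)
delays_3, steps_3 = delays_from_lengths(toy_lengths, 3.0, dt=4.0)
print(steps_3)
# [[ 0  3 10]
#  [ 3  0  5]
#  [10  5  0]]
```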

Create a Brain Network Model

# Define the local dynamics for each node: Reduced Wong-Wang model
# w: excitatory recurrence strength, I_o: external input, INITIAL_STATE: starting condition
dynamics = ReducedWongWang(w=0.5, I_o=0.32, INITIAL_STATE=(0.3,))

# Define coupling between nodes: Linear coupling
# incoming_states: which state variable to couple, G: global coupling strength
coup = LinearCoupling(incoming_states='S', G=0.75)

# Define the network structure: connectivity weights and delays
graph = DenseDelayGraph(weights, delays)

# Add noise to the system: additive noise applied to state variable S
noise = AdditiveNoise(sigma=0.002, apply_to="S", key=jax.random.key(42))

# Create the network by combining all components
# The coupling is labeled "delayed" (vs "instant" for instantaneous coupling)
network = nd.Network(dynamics, {"delayed": coup}, graph, noise=noise)

# Solve the network dynamics using Heun stochastic integrator
# t0=0, t1=10000 (ms), dt=4.0 (integration step size in ms)
result = nd.solve(network, Heun(), t0=0, t1=10_000, dt=4.0)

# Plot time series: all regions (columns) over time (rows)
plt.plot(result.ts, result.ys[:, 0, :])
plt.xlabel("Time (ms)")
plt.ylabel("S (Synaptic gating)")
plt.title("TVB-Optim Network Dynamics");
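The sketch below makes the four components (dynamics, coupling, graph, noise) concrete. It is a NumPy version of what one noisy Heun simulation of a Reduced Wong-Wang network computes and is not tvboptim's implementation. It assumes the following:

  - Classic TVB's default model constants (a = 0.270, b = 0.108, d = 154, gamma = 0.641, tau_s = 100 ms, J_N = 0.2609).
  - Zero delays, as with speed = np.inf.
  - A per-step noise term of sigma * sqrt(dt) * N(0, 1).
  - S clipped to [0, 1].

```python
import numpy as np

# Assumed Reduced Wong-Wang constants (classic TVB defaults), not read from tvboptim
A, B, D, GAMMA, TAU_S, J_N = 0.270, 0.108, 154.0, 0.641, 100.0, 0.2609


def transfer(u):
    """H(u) = u / (1 - exp(-D u)), using the limit 1 / D at u = 0."""
    small = np.abs(u) < 1e-9
    safe_u = np.where(small, 1.0, u)
    return np.where(small, 1.0 / D, safe_u / (-np.expm1(-D * safe_u)))


def rww_drift(S, W, w=0.5, I_o=0.32, G=0.75):
    """dS/dt for every region with linear, instantaneous (zero-delay) coupling."""
    c = G * (W @ S)                          # coupling input: G * sum_j W_ij S_j
    x = w * J_N * S + I_o + J_N * c          # total input current per region
    return -S / TAU_S + (1.0 - S) * transfer(A * x - B) * GAMMA


def heun_additive(S0, W, dt, n_steps, sigma, seed=0):
    """Stochastic Heun with additive noise; predictor and corrector share one noise draw."""
    rng = np.random.default_rng(seed)
    S = np.asarray(S0, dtype=float).copy()
    traj = [S.copy()]
    for _ in range(n_steps):
        noise = sigma * np.sqrt(dt) * rng.standard_normal(S.shape)
        f0 = rww_drift(S, W)
        S_pred = S + dt * f0 + noise                             # Euler predictor
        S = S + 0.5 * dt * (f0 + rww_drift(S_pred, W)) + noise   # trapezoidal corrector
        S = np.clip(S, 0.0, 1.0)                                 # assumed bounds on S
        traj.append(S.copy())
    return np.stack(traj)


# Usage on a hypothetical 4-region network with weights normalized to [0, 1]
rng = np.random.default_rng(1)
W = rng.random((4, 4))
np.fill_diagonal(W, 0.0)
W /= W.max()
ys = heun_additive(np.full(4, 0.3), W, dt=4.0, n_steps=2500, sigma=0.002)
print(ys.shape)  # (2501, 4)
```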

For all the details on TVB-O, see its Documentation. A simple experiment can be created like this:

# Create a brain simulation experiment using the Reduced Wong-Wang model
# TVB-O uses a declarative dictionary-based approach
experiment = SimulationExperiment(
    model={
        "name": "ReducedWongWang",
        "parameters": {
            "w": {"name": "w", "value": 0.5},      # Excitatory recurrence strength
            "I_o": {"name": "I_o", "value": 0.32}, # External input current
        },
        "state_variables": {
            "S": {"initial_value": 0.3},  # Initial synaptic gating variable
        }
    },
    connectivity={
        "conduction_speed": {"name": "cs", "value": np.array([speed])}  # Same conduction speed as above
    },
    coupling={
        "name": "Linear",
        "parameters": {"a": {"name": "a", "value": 0.75}}  # Global coupling strength
    },
    integration={
        "method": "Heun",        # Stochastic Heun integration
        "step_size": 4.0,        # Integration step in ms
        "noise": {"parameters": {"sigma": {"value": 0.002}}},  # Noise amplitude
        "duration": 10_000       # Total simulation time in ms
    },
    monitors={
        "Raw": {"name": "Raw"},  # Monitor raw state variables
    },
)

# Set the connectivity data loaded earlier
experiment.connectivity.weights = weights
experiment.connectivity.lengths = lengths
experiment.connectivity.metadata.number_of_regions = weights.shape[0]

# Configure the experiment (validates and prepares for simulation)
experiment.configure()

# Run the simulation and get results
result_tvbo = experiment.run(format="jax")

# Plot time series: all regions over time
plt.plot(result_tvbo.time, result_tvbo.data[:, 0, :, 0])
plt.xlabel("Time (ms)")
plt.ylabel("S (Synaptic gating)")
plt.title("TVB-O Experiment Results");
Markdown(experiment.model.generate_report())
numbered_print(experiment.render_code(format = "jax"))

Classic simulator definition in TVB:

# Define simulator using classic TVB (object-oriented approach)
sim = simulator.Simulator(
    # Model with parameters as numpy arrays
    model=models.ReducedWongWang(
        w=np.array([0.5]),      # Excitatory recurrence
        I_o=np.array([0.32])    # External input
    ),
    # Connectivity object with structural data
    connectivity=connectivity.Connectivity(
        weights=np.array(weights),
        tract_lengths=np.array(lengths),
        region_labels=np.array(labels),
        centres=np.zeros((weights.shape[0], 3)),  # Dummy 3-D region coordinates
        speed=np.array([speed]),  # Same conduction speed as above
    ),
    # Linear coupling function
    coupling=coupling.Linear(a=np.array([0.75])),
    # Stochastic Heun integrator with additive noise
    integrator=integrators.HeunStochastic(
        noise=integrators.noise.Additive(nsig=np.array([0.5 * 0.002**2])),
        dt=4.0
    ),
    # Initial conditions: shape (history_length, state_vars, regions, modes)
    initial_conditions=0.3 * np.ones((50, 1, 84, 1)),
    simulation_length=10_000,  # Total time in ms
    monitors=(monitors.Raw(),)  # Record raw state variables
)

# Configure the simulator (sets up internal structures)
sim.configure()

# Convert classic TVB simulator to TVB-O experiment format
experiment_tvb = SimulationExperiment.from_tvb_simulator(sim)

# Run using JAX backend
result_tvb = experiment_tvb.run(format="jax")

# Plot time series: all regions over time
plt.plot(result_tvb.Raw.time, result_tvb.Raw.data[:, 0, :, 0])
plt.xlabel("Time (ms)")
plt.ylabel("S (Synaptic gating)")
plt.title("Classic TVB Simulator Results");

Get Model and State

The prepare function separates network setup from execution, returning a pure function for simulation:

# Prepare the network for simulation - returns (model, state)
# This separates setup from execution for better performance and composability
model, state = nd.prepare(network, Heun(), t0=0, t1=10_000, dt=4.0)

The model is a pure JAX function that takes the state and runs the simulation. The state contains all parameters, initial conditions, and precomputed data the simulation needs.

The prepare function converts the TVB-O experiment into a JAX-compatible model function and state object:

# Convert TVB-O experiment to JAX function and state
model_tvbo, state_tvbo = prepare(experiment)

The model is a pure JAX function that takes the state and runs the simulation. The state contains all parameters, initial conditions, and precomputed data the simulation needs.
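The prepare pattern can be mimicked in plain JAX. In the sketch below, prepare_toy is hypothetical. Static settings (the number of steps and dt) are closed over, and everything that may change lives in a state dict. Because toy_model reads only its argument, it can be passed directly to jax.jit and jax.grad:

```python
import jax
import jax.numpy as jnp


def prepare_toy(n_steps, dt):
    """Hypothetical analogue of prepare(): fix static settings, return (model, state)."""

    def model(state):
        # Pure: the output depends only on `state` (plus the closed-over n_steps and dt)
        def step(x, _):
            x_next = x + dt * (-x / state["tau"] + state["I"])  # Euler step of dx/dt = -x/tau + I
            return x_next, x_next

        _, xs = jax.lax.scan(step, state["x0"], None, length=n_steps)
        return xs

    state = {"tau": jnp.array(10.0), "I": jnp.array(0.05), "x0": jnp.zeros(3)}
    return model, state


toy_model, toy_state = prepare_toy(n_steps=200, dt=0.1)
toy_xs = jax.jit(toy_model)(toy_state)                            # compile once, rerun for any state
toy_grads = jax.grad(lambda s: toy_model(s)[-1].sum())(toy_state)  # gradient for every state leaf
print(toy_xs.shape)  # (200, 3)
```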

Understand the State Object & Parameters

The state is of type tvbo.datamodel.tvbo_datamodel.Bunch, a dict whose entries can also be read and set as attributes (for example state.dynamics.J_N). It is also a JAX pytree, so it works with all of JAX's transformations (jit, grad, vmap, pmap). Think of it as a tree holding all parameters and initial conditions that uniquely define a simulation:

state.keys()
dict_keys(['dynamics', 'coupling', 'external', 'graph', 'initial_state', '_internal', 'noise'])
state_tvbo
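The state has two properties: dict entries with attribute access, and pytree registration. Both can be reproduced in a few lines. AttrDict below is a hypothetical stand-in and not tvbo's Bunch. Registering it with jax.tree_util.register_pytree_node is what lets JAX utilities such as tree_map traverse it:

```python
import jax
import jax.numpy as jnp


class AttrDict(dict):
    """Hypothetical Bunch-like container: a dict with attribute-style get and set."""

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as err:
            raise AttributeError(name) from err

    def __setattr__(self, name, value):
        self[name] = value


def _flatten(d):
    keys = tuple(sorted(d.keys()))
    return tuple(d[k] for k in keys), keys   # (children, static aux data)


def _unflatten(keys, children):
    return AttrDict(zip(keys, children))


jax.tree_util.register_pytree_node(AttrDict, _flatten, _unflatten)

toy_state = AttrDict(
    dynamics=AttrDict(J_N=jnp.array(0.2609), w=jnp.array(0.5)),
    coupling=AttrDict(G=jnp.array(0.75)),
)
toy_state.coupling.G = jnp.array(0.5)                          # attribute-style set
doubled = jax.tree_util.tree_map(lambda x: 2 * x, toy_state)   # result is again an AttrDict
n_leaves = len(jax.tree_util.tree_leaves(toy_state))           # 3 arrays in total
```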

Simulate the Model

To run a simulation, you can simply call the model.

# Run the simulation - model returns result object
result = model(state)
Visualization
# Plot raw neural activity for all 84 brain regions over time
# Shape: [time_points, state_variables, regions]
plt.plot(result.data[:,0,:], color = "royalblue", alpha = 0.5)
plt.xlabel("Time step (dt = 4 ms)")
plt.ylabel("Neural Activity")
plt.title("Raw Neural Activity Across All Brain Regions");

# Run the simulation - model returns (raw_activity, )
result = model_tvbo(state_tvbo)
raw = result[0]
Visualization
# Plot raw neural activity for all 84 brain regions over time
# Shape: [time_points, state_variables, regions, modes]
plt.plot(raw.data[:,0,:,0], color = "royalblue", alpha = 0.5)
plt.xlabel("Time step (dt = 4 ms)")
plt.ylabel("Neural Activity")
plt.title("Raw Neural Activity Across All Brain Regions");

Wrap the Model to Create Observations

As a simple observation, we take the mean activity over the last 500 time steps (the final 2000 ms at dt = 4 ms), which skips the initial transient.

def observation(state):
    """
    Extract a simple observation from the simulation.
    We use the mean activity of the last 500 timesteps to avoid transient effects.
    """
    ts = model(state).data[-500:,0,:]  # Last 500 timesteps, skip transient
    mean_activity = jnp.mean(ts)  # Average across time and regions
    return mean_activity

# Test the observation function
print(f"Mean activity: {observation(state):.4f}")
Mean activity: 0.6910
def observation_tvbo(state):
    """
    Extract a simple observation from the simulation.
    We use the mean activity of the last 500 timesteps to avoid transient effects.
    """
    ts = model_tvbo(state)[0].data[-500:,0,:,0]  # Last 500 timesteps, skip transient
    mean_activity = jnp.mean(ts)  # Average across time and regions
    return mean_activity

# Test the observation function
print(f"Mean activity: {observation_tvbo(state_tvbo):.4f}")

Explore Across a Parameter Space

We use a Space to explore how J_N (excitatory recurrence) and the global coupling strength (G in TVB-Optim, a in TVB-O) affect the observation. The cache decorator stores the result of this expensive computation so it only has to be computed once. We also parallelize the exploration with ParallelExecution and n_pmap = 8. This works on a single CPU because, before importing JAX, we told it to expose 8 host devices (the pmap trick).
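The pmap trick can be sketched in plain JAX. This assumes the XLA flag is set before JAX is first imported, in which case JAX exposes several CPU devices. jax.pmap then splits the leading axis of the parameter grid across those devices, and jax.vmap vectorizes within each device. observable is a hypothetical cheap stand-in for "simulate, then reduce to a scalar", and this is not ParallelExecution's implementation:

```python
import os

# Must be set before JAX is imported: expose 8 virtual CPU devices
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")
os.environ.setdefault("JAX_PLATFORMS", "cpu")

import jax
import jax.numpy as jnp


def observable(g, j_n):
    # Hypothetical stand-in for a simulation-based observation
    return jnp.tanh(g * j_n)


n_grid = 32
g_grid, jn_grid = jnp.meshgrid(jnp.linspace(0.0, 1.0, n_grid),
                               jnp.linspace(0.0, 0.5, n_grid), indexing="ij")
g_flat, jn_flat = g_grid.ravel(), jn_grid.ravel()   # 1024 parameter combinations

n_dev = jax.local_device_count()
pad = (-g_flat.size) % n_dev                         # pad so the batch splits evenly
g_dev = jnp.pad(g_flat, (0, pad)).reshape(n_dev, -1)
jn_dev = jnp.pad(jn_flat, (0, pad)).reshape(n_dev, -1)

# pmap maps the leading axis over devices; vmap vectorizes within each device
out = jax.pmap(jax.vmap(observable))(g_dev, jn_dev)
values = out.reshape(-1)[: g_flat.size].reshape(n_grid, n_grid)
print(n_dev, values.shape)
```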

# Create a copy of the state for parameter exploration
exploration_state = copy.deepcopy(state)

# Set up parameter exploration by setting GridAxis in the state
n = 32  # 32x32 grid = 1024 parameter combinations
exploration_state.coupling.delayed.G = GridAxis(0.0, 1.0, n)
exploration_state.dynamics.J_N = GridAxis(0.0, 0.5, n)

# Create a grid space for systematic parameter exploration
params_space = Space(exploration_state, mode="product")

@cache("explore_tvboptim", redo = False)  # Cache results to avoid recomputation
def explore():
    # Use parallel execution with 8 virtual devices (pmap trick)
    execution = ParallelExecution(observation, params_space, n_pmap=8)
    return execution.run()

# Run the exploration (or load from cache)
exploration = explore()
Loading explore_tvboptim from cache, last modified 2025-12-17 15:27:36.942677
Visualization
# Collect parameter values and results
coupling_vals = []
j_n_vals = []
activity_vals = []
for params, activity in zip(params_space, exploration):
    coupling_vals.append(params.coupling.delayed.G)
    j_n_vals.append(params.dynamics.J_N)
    activity_vals.append(activity)

# Convert to arrays
coupling_vals = jnp.array(coupling_vals)
j_n_vals = jnp.array(j_n_vals)
activity_vals = jnp.array(activity_vals)

# Create properly gridded heatmap
# Get unique sorted values for each parameter
unique_coupling = jnp.sort(jnp.unique(coupling_vals))
unique_j_n = jnp.sort(jnp.unique(j_n_vals))

# Create grid and map values
grid = jnp.zeros((len(unique_j_n), len(unique_coupling)))
# g_val (not `coupling`) avoids shadowing the TVB coupling module imported earlier
for g_val, j_n, activity in zip(coupling_vals, j_n_vals, activity_vals):
    i = jnp.where(unique_j_n == j_n)[0][0]
    j = jnp.where(unique_coupling == g_val)[0][0]
    grid = grid.at[i, j].set(activity)

# Visualize the parameter space exploration
plt.figure(figsize=(8.1, 6.48))
im = plt.imshow(grid, aspect="auto", origin="lower",
                extent=[unique_coupling[0], unique_coupling[-1],
                       unique_j_n[0], unique_j_n[-1]],
                cmap="cividis")
plt.xlabel("Global Coupling (G)")
plt.ylabel("Recurrent Excitation (J_N) [nA]")
plt.title("Mean Activity Across Parameter Space")
plt.colorbar(im, label="Mean Activity")
plt.grid(True, alpha=0.2, color='white')

# Create a copy of the state for parameter exploration
exploration_state_tvbo = copy.deepcopy(state_tvbo)

# Set up parameter exploration by setting GridAxis in the state
n = 32  # 32x32 grid = 1024 parameter combinations
exploration_state_tvbo.parameters.coupling.a = GridAxis(0.0, 1.0, n)
exploration_state_tvbo.parameters.model.J_N = GridAxis(0.0, 0.5, n)

# Create a grid space for systematic parameter exploration
params_space_tvbo = Space(exploration_state_tvbo, mode="product")

@cache("explore_tvbo", redo = False)  # Cache results to avoid recomputation
def explore():
    # Use parallel execution with 8 virtual devices (pmap trick)
    execution = ParallelExecution(observation_tvbo, params_space_tvbo, n_pmap=8)
    return execution.run()

# Run the exploration (or load from cache)
exploration = explore()
Visualization
# Collect parameter values and results
coupling_vals_tvbo = []
j_n_vals_tvbo = []
activity_vals_tvbo = []
for params, activity in zip(params_space_tvbo, exploration):
    coupling_vals_tvbo.append(params.parameters.coupling.a)
    j_n_vals_tvbo.append(params.parameters.model.J_N)
    activity_vals_tvbo.append(activity)

# Convert to arrays
coupling_vals_tvbo = jnp.array(coupling_vals_tvbo)
j_n_vals_tvbo = jnp.array(j_n_vals_tvbo)
activity_vals_tvbo = jnp.array(activity_vals_tvbo)

# Create properly gridded heatmap
# Get unique sorted values for each parameter
unique_coupling_tvbo = jnp.sort(jnp.unique(coupling_vals_tvbo))
unique_j_n_tvbo = jnp.sort(jnp.unique(j_n_vals_tvbo))

# Create grid and map values
grid_tvbo = jnp.zeros((len(unique_j_n_tvbo), len(unique_coupling_tvbo)))
# a_val (not `coupling`) avoids shadowing the TVB coupling module imported earlier
for a_val, j_n, activity in zip(coupling_vals_tvbo, j_n_vals_tvbo, activity_vals_tvbo):
    i = jnp.where(unique_j_n_tvbo == j_n)[0][0]
    j = jnp.where(unique_coupling_tvbo == a_val)[0][0]
    grid_tvbo = grid_tvbo.at[i, j].set(activity)

# Visualize the parameter space exploration
plt.figure(figsize=(10, 8))
im = plt.imshow(grid_tvbo, aspect="auto", origin="lower",
                extent=[unique_coupling_tvbo[0], unique_coupling_tvbo[-1],
                       unique_j_n_tvbo[0], unique_j_n_tvbo[-1]],
                cmap="cividis")
plt.xlabel("Global Coupling (a)")
plt.ylabel("Recurrent Excitation (J_N) [nA]")
plt.title("Mean Activity Across Parameter Space")
plt.colorbar(im, label="Mean Activity")
plt.grid(True, alpha=0.2, color='white')
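The gridding loops above search the unique values with jnp.where for every sample and create a new JAX array on each .at[].set() call, so they scale poorly for large grids. Below is a vectorized alternative. It is a sketch using NumPy's np.unique(..., return_inverse=True), is not part of tvboptim, and makes no assumption about the order in which Space yields the samples:

```python
import numpy as np


def to_grid(x_vals, y_vals, z_vals):
    """Scatter (x, y, z) samples from a product grid into a 2-D array indexed [y, x]."""
    xs, ix = np.unique(np.asarray(x_vals), return_inverse=True)  # sorted unique x, index of each sample
    ys, iy = np.unique(np.asarray(y_vals), return_inverse=True)
    grid = np.full((ys.size, xs.size), np.nan)                    # NaN marks any missing combination
    grid[iy.ravel(), ix.ravel()] = np.asarray(z_vals)
    return xs, ys, grid


# Usage on a small shuffled product grid
g = np.linspace(0.0, 1.0, 4)
j = np.linspace(0.0, 0.5, 3)
GG, JJ = np.meshgrid(g, j)             # shape (3, 4): rows follow j, columns follow g
Z = GG + 10 * JJ
perm = np.random.default_rng(0).permutation(Z.size)
xs, ys, grid_fast = to_grid(GG.ravel()[perm], JJ.ravel()[perm], Z.ravel()[perm])
```

With the variables above, to_grid(coupling_vals, j_n_vals, activity_vals) would return arrays that can be passed to plt.imshow in the same way.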

Define a Loss and Optimize

Suppose our goal is for every region to reach a mean activity of 0.5. We define a loss that penalizes each region's deviation from this target: the mean squared error between each region's time-averaged activity and 0.5.

def loss(state):
    """
    Define a loss function that penalizes deviations from target activity.
    Goal: Each brain region should have mean activity of 0.5
    """
    ts = model(state).data[-500:,0,:]  # Skip transient period
    mean_activity = jnp.mean(ts, axis = 0)  # Average over time for each region
    # Compute mean squared error between actual and target (0.5) activity
    return jnp.mean((mean_activity - 0.5)**2)  # Region-wise difference

# Test the loss function
print(f"Current loss: {loss(state):.6f}")
Current loss: 0.098067
def loss_tvbo(state):
    """
    Define a loss function that penalizes deviations from target activity.
    Goal: Each brain region should have mean activity of 0.5
    """
    ts = model_tvbo(state)[0].data[-500:,0,:,0]  # Skip transient period
    mean_activity = jnp.mean(ts, axis = 0)  # Average over time for each region
    # Compute mean squared error between actual and target (0.5) activity
    return jnp.mean((mean_activity - 0.5)**2)  # Region-wise difference

# Test the loss function
print(f"Current loss: {loss_tvbo(state_tvbo):.6f}")

Values in the state are JAX arrays. Wrapping a value in the Parameter type marks it as free. The optimizer then computes gradients for, and updates, only the values wrapped in Parameter, and every other value stays fixed. Optax supplies the update rule (here Adam).

# Mark the excitatory recurrence parameter as free for optimization
state.dynamics.J_N = Parameter(state.dynamics.J_N)
show_parameters(state)  # Display all parameters available for optimization

# Create an optimizer using Adam with automatic differentiation
optimizer = OptaxOptimizer(
    loss,                           # Loss function to minimize
    optax.adam(0.005),             # Adam optimizer with learning rate 0.005
    callback=DefaultPrintCallback(every=5) # Print progress during optimization
)

# Run optimization using forward-mode automatic differentiation
# Forward mode is efficient when there are few free parameters (here only J_N)
%time optimized_state, _ = optimizer.run(state, max_steps=50, mode="fwd")
Parameters
├── _internal: Bunch
├── coupling: Bunch
├── dynamics
│   ├── J_N
│   │   └── value: 0.26089999079704285
├── external: Bunch
├── graph: DenseDelayGraph
├── initial_state: Bunch
└── noise: Bunch
Step 0: 0.098067
Step 5: 0.091565
Step 10: 0.086876
Step 15: 0.088758
Step 20: 0.089644
Step 25: 0.089902
Step 30: 0.089752
Step 35: 0.089331
Step 40: 0.088722
Step 45: 0.087983
CPU times: user 10.3 s, sys: 223 ms, total: 10.5 s
Wall time: 8.19 s
# Mark the excitatory recurrence parameter as free for optimization
state_tvbo.parameters.model.J_N = Parameter(state_tvbo.parameters.model.J_N)
show_parameters(state_tvbo)  # Display all parameters available for optimization

# Create an optimizer using Adam with automatic differentiation
optimizer_tvbo = OptaxOptimizer(
    loss_tvbo,                           # Loss function to minimize
    optax.adam(0.005),             # Adam optimizer with learning rate 0.005
    callback=DefaultPrintCallback(every=5) # Print progress during optimization
)

# Run optimization using forward-mode automatic differentiation
# Forward mode is efficient when there are few free parameters (here only J_N)
%time optimized_state_tvbo, _ = optimizer_tvbo.run(state_tvbo, max_steps=50, mode="fwd")
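One way that marking a value as free can drive gradients is sketched below. Param, split, and merge are hypothetical and are not tvboptim's Parameter machinery. The state is flattened with Param objects treated as leaves. Only the wrapped values are differentiated and updated, while all other leaves are held fixed. Plain gradient descent stands in for the Adam update:

```python
import jax
import jax.numpy as jnp


class Param:
    """Hypothetical marker (not tvboptim.types.Parameter): flags a leaf as free."""

    def __init__(self, value):
        self.value = jnp.asarray(value, dtype=jnp.float32)


def is_param(x):
    return isinstance(x, Param)


def split(state):
    """Flatten the state (Param objects count as leaves) and pull out the free values."""
    leaves, treedef = jax.tree_util.tree_flatten(state, is_leaf=is_param)
    free = [leaf.value for leaf in leaves if is_param(leaf)]
    mask = [is_param(leaf) for leaf in leaves]
    return free, (leaves, mask, treedef)


def merge(free, static):
    """Rebuild a state of plain arrays, inserting the (possibly traced) free values."""
    leaves, mask, treedef = static
    free_iter = iter(free)
    plain = [next(free_iter) if m else leaf for leaf, m in zip(leaves, mask)]
    return jax.tree_util.tree_unflatten(treedef, plain)


def toy_loss(s):
    # Hypothetical stand-in for a simulation-based loss
    return (s["dynamics"]["J_N"] * s["coupling"]["G"] - 0.1) ** 2


toy_state = {"dynamics": {"J_N": Param(0.2609)}, "coupling": {"G": jnp.array(0.75)}}
free, static = split(toy_state)
grad_fn = jax.grad(lambda f: toy_loss(merge(f, static)))

grads = grad_fn(free)        # a single entry: only J_N is free
for _ in range(50):          # gradient descent on the free values only
    free = [f - 0.5 * g for f, g in zip(free, grad_fn(free))]
fitted = merge(free, static)  # G is unchanged, J_N is fitted
```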

Visualize the Fitted Model

# Simulate with optimized parameters
ts_optimized = model(optimized_state).data[:,0,:]
Visualization
plt.figure(figsize=(10, 6))
plt.plot(ts_optimized, alpha = 0.5, color = "royalblue")
plt.hlines(0.5, 0, 2500, color = "black", linewidth = 2.5, label="Target (0.5)")
plt.hlines(observation(optimized_state), 0, 2500, color = "red", linewidth = 2.5,
           label=f"Actual Mean ({observation(optimized_state):.3f})")
plt.xlabel("Time step (dt = 4 ms)")
plt.ylabel("Neural Activity")
plt.title("Optimized Neural Activity")
plt.legend()
plt.grid(True, alpha=0.3)

# Simulate with optimized parameters
ts_optimized_tvbo = model_tvbo(optimized_state_tvbo)[0].data[:,0,:,0]
Visualization
plt.figure(figsize=(10, 6))
plt.plot(ts_optimized_tvbo, alpha = 0.5, color = "royalblue")
plt.hlines(0.5, 0, 2500, color = "black", linewidth = 2.5, label="Target (0.5)")
plt.hlines(observation_tvbo(optimized_state_tvbo), 0, 2500, color = "red", linewidth = 2.5,
           label=f"Actual Mean ({observation_tvbo(optimized_state_tvbo):.3f})")
plt.xlabel("Time step (dt = 4 ms)")
plt.ylabel("Neural Activity")
plt.title("Optimized Neural Activity")
plt.legend()
plt.grid(True, alpha=0.3)

Well, the mean is close to the target, but most regions are either too high or too low. We can make parameters heterogeneous to adjust that.
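The observation is a global mean, while the loss is a region-wise error, and the two can disagree. Made-up numbers show this: if half of 84 regions sit at 0.1 and half at 0.9, the global mean hits the target exactly while the region-wise loss stays large. The values are hypothetical, not simulation output:

```python
import numpy as np

# Hypothetical time-averaged activity of 84 regions: half low, half high
region_means = np.where(np.arange(84) < 42, 0.1, 0.9)

global_mean = region_means.mean()                     # what `observation` reports
regionwise_loss = np.mean((region_means - 0.5) ** 2)  # what `loss` penalizes
print(f"{global_mean:.3f} {regionwise_loss:.3f}")     # 0.500 0.160
```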

Heterogeneous Parameters

The previous optimization used global parameters (same value for all brain regions). Now we’ll make parameters region-specific to achieve better control.

We switch to reverse-mode automatic differentiation, which is more efficient when there are many free parameters (here 84, one per region):
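For a scalar loss of 84 inputs, forward mode needs one Jacobian-vector product per input, whereas reverse mode needs a single backward pass. Both give the same gradient. The sketch below checks this with a hypothetical toy loss, using jax.jacfwd for forward mode and jax.grad for reverse mode:

```python
import jax
import jax.numpy as jnp


def toy_loss(j_n):
    # Hypothetical region-wise loss: each region's "activity" is tanh(J_N), target 0.5
    return jnp.mean((jnp.tanh(j_n) - 0.5) ** 2)


j_n = jnp.full((84,), 0.2609)        # one value per region

g_rev = jax.grad(toy_loss)(j_n)      # reverse mode: one forward and one backward pass
g_fwd = jax.jacfwd(toy_loss)(j_n)    # forward mode: one JVP per input, 84 here
```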

# Make parameters heterogeneous: one value per brain region (84 regions)
optimized_state.dynamics.J_N.shape = (weights.shape[0],)  # Excitatory recurrence per region
print(f"J_N parameter shape: {optimized_state.dynamics.J_N.shape}")

# Create optimizer for heterogeneous parameters
optimizer_het = OptaxOptimizer(
    loss,                                    # Same loss function
    optax.adam(0.005),                      # Same learning rate as the global fit
    callback=MultiCallback([DefaultPrintCallback(every=10), StopLossCallback(stop_loss=0.005)]) # Print every 10 steps, stop early at loss 0.005
)

# Use reverse-mode AD (more efficient for many parameters)
optimized_state_het, _ = optimizer_het.run(optimized_state, max_steps=201, mode="rev")
J_N parameter shape: (84,)
Step 0: 0.087251
Step 10: 0.055372
Step 20: 0.028170
Step 30: 0.010199
Step 40: 0.010310
Step 50: 0.025366
Step 60: 0.044674
Step 70: 0.050680
Step 80: 0.052546
Step 90: 0.052947
Step 100: 0.052797
Step 110: 0.052423
Step 120: 0.051948
Step 130: 0.051414
Step 140: 0.050837
Step 150: 0.050223
Step 160: 0.049572
Step 170: 0.048885
Step 180: 0.048160
Step 190: 0.047396
Step 200: 0.046594
# Make parameters heterogeneous: one value per brain region (84 regions), one mode
optimized_state_tvbo.parameters.model.J_N.shape = (weights.shape[0],1)  # Excitatory recurrence per region
print(f"J_N parameter shape: {optimized_state_tvbo.parameters.model.J_N.shape}")

# Create optimizer for heterogeneous parameters
optimizer_het_tvbo = OptaxOptimizer(
    loss_tvbo,                              # Same loss function
    optax.adam(0.005),                      # Same learning rate as the global fit
    callback=MultiCallback([DefaultPrintCallback(every=10), StopLossCallback(stop_loss=0.005)]) # Print every 10 steps, stop early at loss 0.005
)

# Use reverse-mode AD (more efficient for many parameters)
optimized_state_het_tvbo, _ = optimizer_het_tvbo.run(optimized_state_tvbo, max_steps=201, mode="rev")

Most regions now settle close to the target level once the initial transient has passed. Starting from initial conditions at the target activity could shorten that transient.

# Simulate with heterogeneous parameters
ts_optimized_het = model(optimized_state_het).data[:,0,:]
Visualization
plt.figure(figsize=(15, 10))

# Plot 1: Time series for all regions
plt.subplot(2, 1, 1)
plt.plot(ts_optimized_het, alpha = 0.5, color = "royalblue")
plt.hlines(0.5, 0, 2500, color = "black", linewidth = 2.5, label="Target (0.5)")
plt.hlines(observation(optimized_state_het), 0, 2500, color = "red", linewidth = 2.5,
           label=f"Mean ({observation(optimized_state_het):.3f})")
plt.xlabel("Time step (dt = 4 ms)")
plt.ylabel("Neural Activity")
plt.title("Heterogeneous Optimization")
plt.legend()
plt.grid(True, alpha=0.3)

# Plot 2: J_N parameters vs mean regionwise coupling
mean_coupling = jnp.mean(weights, axis=1)
plt.subplot(2, 1, 2)
plt.scatter(mean_coupling, optimized_state_het.dynamics.J_N.value.flatten(), alpha=0.7, color="k", s=30)
plt.xlabel(r"Mean Regionwise Coupling $\frac{1}{N}\sum_j C_{ij}$")
plt.ylabel("J_N [nA]")
plt.title("Fitted J_N Parameters")
plt.grid(True, alpha=0.3)

plt.tight_layout()

# Simulate with heterogeneous parameters
ts_optimized_het_tvbo = model_tvbo(optimized_state_het_tvbo)[0].data[:,0,:,0]
Visualization
plt.figure(figsize=(15, 10))

# Plot 1: Time series for all regions
plt.subplot(2, 1, 1)
plt.plot(ts_optimized_het_tvbo, alpha = 0.5, color = "royalblue")
plt.hlines(0.5, 0, 2500, color = "black", linewidth = 2.5, label="Target (0.5)")
plt.hlines(observation_tvbo(optimized_state_het_tvbo), 0, 2500, color = "red", linewidth = 2.5,
           label=f"Mean ({observation_tvbo(optimized_state_het_tvbo):.3f})")
plt.xlabel("Time step (dt = 4 ms)")
plt.ylabel("Neural Activity")
plt.title("Heterogeneous Optimization")
plt.legend()
plt.grid(True, alpha=0.3)

# Plot 2: J_N parameters vs mean regionwise coupling
mean_coupling = jnp.mean(experiment.connectivity.weights, axis=1)
plt.subplot(2, 1, 2)
plt.scatter(mean_coupling, optimized_state_het_tvbo.parameters.model.J_N.value.flatten(), alpha=0.7, color="k", s=30)
plt.xlabel(r"Mean Regionwise Coupling $\frac{1}{N}\sum_j C_{ij}$")
plt.ylabel("J_N [nA]")
plt.title("Fitted J_N Parameters")
plt.grid(True, alpha=0.3)

plt.tight_layout()

Why is this problem interesting? The Reduced Wong-Wang model has two fixed-point branches: low activity (~0.1) and high activity (~0.9). Each region tends to approach one of them, so reaching the intermediate target level requires tuning each region to sit between the two. This is closely related to feedback inhibition control (FIC), which likewise tunes a local parameter in every region to hold activity at a target level.
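For a single node, the branch structure can be probed by scanning dS/dt on [0, 1] and locating sign changes. This sketch assumes classic TVB's default Reduced Wong-Wang constants. It also replaces the network input G * sum_j C_ij S_j with a hypothetical constant, network_input. The fixed points it finds depend on w, I_o, and that input, so they need not match the approximate values quoted above:

```python
import numpy as np

# Assumed Reduced Wong-Wang constants (classic TVB defaults), not read from tvboptim
A, B, D, GAMMA, TAU_S, J_N = 0.270, 0.108, 154.0, 0.641, 100.0, 0.2609


def transfer(u):
    """H(u) = u / (1 - exp(-D u)), using the limit 1 / D at u = 0."""
    small = np.abs(u) < 1e-9
    safe_u = np.where(small, 1.0, u)
    return np.where(small, 1.0 / D, safe_u / (-np.expm1(-D * safe_u)))


def drift(S, w=0.5, I_o=0.32, network_input=0.0):
    """dS/dt of one node; network_input is a hypothetical stand-in for the coupling term."""
    x = w * J_N * S + I_o + J_N * network_input
    return -S / TAU_S + (1.0 - S) * transfer(A * x - B) * GAMMA


def fixed_points(**kwargs):
    """Roots of dS/dt on [0, 1]: detect sign changes, then refine by linear interpolation."""
    S = np.linspace(0.0, 1.0, 10_001)
    f = drift(S, **kwargs)
    idx = np.nonzero(np.sign(f[:-1]) != np.sign(f[1:]))[0]
    return S[idx] - f[idx] * (S[idx + 1] - S[idx]) / (f[idx + 1] - f[idx])


low = fixed_points()                     # isolated node
high = fixed_points(network_input=1.0)   # strongly driven node
print(np.round(low, 2), np.round(high, 2))  # roughly 0.06 and 0.83
```

With these settings, the node has a single low fixed point without network input and a single high one with network_input = 1.0.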

Key Concepts Demonstrated

This tutorial showcased several important TVB-Optim concepts:

  1. Parameter Types: The Parameter class wraps JAX arrays with additional functionality and constraints, and marks them as free for optimization.
  2. State Management: The state object is a JAX pytree containing all simulation parameters and initial conditions.
  3. Spaces: Space enables systematic parameter exploration by placing Axis objects (such as GridAxis) anywhere in the state.
  4. Execution Strategies: ParallelExecution uses JAX’s pmap to distribute computation over parameter sets across devices.
  5. Optimization: OptaxOptimizer provides gradient-based optimization with automatic differentiation.
  6. Caching: The @cache decorator saves expensive computations for reuse.
  7. Heterogeneous Parameters: Region-specific parameters enable fine-grained control over brain dynamics.

These tools enable efficient exploration and optimization of complex brain network models at scale.