Warning: Experimental Feature

The interface to design network dynamics is still under development and might change in future releases.

What is Network Dynamics?

Network Dynamics is a pure JAX interface for defining and simulating brain network models. If you’re familiar with TVB (The Virtual Brain), you’ll recognize the core concepts: dynamics models, structural connectivity, coupling functions, and numerical integration. The key difference is that Network Dynamics is built entirely on JAX, making it:

  • Differentiable: Gradients flow through the entire simulation for parameter optimization
  • JIT-compilable: Automatic compilation for high-performance execution
  • GPU-ready: Support for GPU acceleration
  • Explicit: Clear separation of concerns with minimal magic

This document provides a gentle introduction through a complete working example, comparing the traditional TVB approach with the Network Dynamics framework. For detailed documentation on each component, see the linked sections throughout.

A Complete Example: TVB vs Network Dynamics

Let’s simulate a resting-state brain network using the Reduced Wong-Wang model with empirical connectivity. We’ll show both approaches side-by-side so you can see the correspondence between frameworks.

Network Dynamics implements TVB’s standard mathematical framework for brain network simulations. While this framework is the default, Network Dynamics can in principle model other system architectures due to its flexible JAX-based design.

The standard TVB framework simulates systems of stochastic differential equations with the following structure:

\[ \begin{aligned} dS_i &= \left[f_d(S_i, \theta^d, C_i, I_i) \right]dt + g(S_i, \theta^g)\, dW_i \\ C_i &= f_c^{\text{post}}\left(\sum_j A_{ij}\, f_c^{\text{pre}}(S_i, S_j(t-\tau_{ij}), \theta^c), S_i, \theta^c\right) \end{aligned} \]

State Evolution (first equation):

  • \(S_i\) - state variables at node \(i\) (e.g., membrane potential, synaptic gating)
  • \(f_d\) - dynamics function defining local temporal evolution with parameters \(\theta^d\)
  • \(C_i\) - coupling input from connected nodes
  • \(I_i\) - external input (stimulation, driving signals)
  • \(g\) - diffusion coefficient controlling noise intensity with parameters \(\theta^g\)
  • \(dW_i\) - Wiener process (Brownian motion) for stochastic fluctuations

Coupling (second equation):

  • \(f_c^{\text{pre}}\) - pre-aggregation transformation (e.g., state differences, nonlinearities)
  • \(A_{ij}\) - structural connectivity weight from node \(j\) to node \(i\)
  • \(\tau_{ij}\) - transmission delay from tract length and conduction speed
  • \(f_c^{\text{post}}\) - post-aggregation transformation (e.g., gain, offset)
  • \(\theta^c\) - coupling parameters (strength, thresholds, etc.)

Each code component in Network Dynamics directly implements one part of this mathematical structure, making the mapping from theory to implementation transparent.
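
For orientation, here is that mapping as a compact legend; each component is introduced in the example and in the Component Breakdown below:

# Theory-to-code legend:
# f_d (local dynamics)            -> dynamics, e.g. ReducedWongWang
# A_ij, tau_ij (weights, delays)  -> graph, e.g. DenseDelayGraph
# f_c^pre, sum, f_c^post          -> coupling, e.g. DelayedLinearCoupling
# g, dW_i (stochastic term)       -> noise, e.g. AdditiveNoise
# I_i (external input)            -> external inputs, e.g. DataInput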

TVB Framework

Imports
import warnings
import numpy as np

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    from tvb.simulator.lab import (
        models, connectivity, coupling,
        integrators, monitors, simulator
    )

from tvboptim.data import (
    load_structural_connectivity
)
# Load connectivity data
weights, lengths, labels = (
    load_structural_connectivity("dk_average")
)

# Normalize weights
weights = weights / np.max(weights)

# Create connectivity object
conn = connectivity.Connectivity(
    weights=np.array(weights),
    tract_lengths=np.array(lengths),
    region_labels=np.array(labels),
    centres=np.zeros(84),
    speed=np.array([3.0])
)

# Create dynamics (defaults used)
model = models.ReducedWongWang(
    w=np.array([0.7])
)


# Create coupling
coupl = coupling.Linear(
    a=np.array([0.5])
)

# Create noise
# Note: TVB uses nsig = 0.5 * sigma^2
# where sigma is the desired std
sigma = 0.01
sigma_tvb = 0.5 * sigma**2
integrator = integrators.HeunStochastic(
    dt=1.0,
    noise=integrators.noise.Additive(
        nsig=np.array([sigma_tvb])
    )
)

# Build simulator
sim = simulator.Simulator(
    model=model,
    connectivity=conn,
    coupling=coupl,
    integrator=integrator,
    initial_conditions=0.1 * np.ones((100,1,84,1)),
)

sim.configure()

# Run simulation
print("Running TVB...")
(time_tvb, data_tvb), = sim.run(
    simulation_length=1000.0
)

print(f"Shape: {data_tvb.shape}")
Running TVB...
Shape: (1000, 1, 84, 1)

Network Dynamics Framework

Imports
import jax
import jax.numpy as jnp
from tvboptim.experimental.network_dynamics import (
    Network, solve
)
from tvboptim.experimental.network_dynamics.dynamics.tvb import (
    ReducedWongWang
)
from tvboptim.experimental.network_dynamics.coupling import (
    DelayedLinearCoupling
)
from tvboptim.experimental.network_dynamics.graph import (
    DenseDelayGraph
)
from tvboptim.experimental.network_dynamics.noise import (
    AdditiveNoise
)
from tvboptim.experimental.network_dynamics.solvers import (
    Heun
)
from tvboptim.data import (
    load_structural_connectivity
)
# Load connectivity data
weights, lengths, labels = (
    load_structural_connectivity("dk_average")
)

# Normalize weights
weights = weights / jnp.max(weights)
# Compute delays from tract lengths
delays = lengths / 3.0  # conduction speed of 3 mm/ms (matches TVB's speed)

# Create graph (connectivity + delays)
graph = DenseDelayGraph(
    weights, delays, region_labels=labels
)

# Create dynamics (defaults used)
dynamics = ReducedWongWang(
    INITIAL_STATE=0.1,
    w=0.7
)

# Create coupling
coupling = DelayedLinearCoupling(
    incoming_states='S', G=0.5
)

# Create noise
# Note: Network Dynamics uses sigma
# as the std of additive noise
noise = AdditiveNoise(
    sigma=0.01, key=jax.random.key(42)
)

# Build network
network = Network(
    dynamics=dynamics,
    coupling={"delayed": coupling},
    graph=graph,
    noise=noise
)

# Create solver
solver = Heun()

# Run simulation
print("Running Network Dynamics...")
result = solve(
    network, solver,
    t0=0.0, t1=1000.0, dt=1.0
)

print(f"Shape: {result.ys.shape}")
Running Network Dynamics...
Shape: (1000, 1, 84)

Comparing the Results

Let’s visualize both simulations to verify they produce comparable dynamics:

Visualization Code
import matplotlib.pyplot as plt

fig, axes = plt.subplots(3, 2, figsize=(8.1, 5.786), sharey=True)

# Select a few representative nodes
nodes_to_plot = [0, 20, 40]
colors = ['#3498db', '#e74c3c', '#2ecc71']

# Time series comparison
for idx, node in enumerate(nodes_to_plot):
    # TVB result
    axes[idx, 0].plot(time_tvb - 50, data_tvb[:, 0, node, 0],
                      color=colors[idx], linewidth=1, alpha=0.8)
    axes[idx, 0].set_ylabel(f'S (Node {node})\n{labels[node][:15]}...', fontsize=9)
    axes[idx, 0].grid(True, alpha=0.3)
    if idx == 0:
        axes[idx, 0].set_title('TVB Framework', fontweight='bold', fontsize=11)
    if idx == len(nodes_to_plot) - 1:
        axes[idx, 0].set_xlabel('Time [ms]')

    # Network Dynamics result
    axes[idx, 1].plot(result.ts, result.ys[:, 0, node],
                      color=colors[idx], linewidth=1, alpha=0.8)
    axes[idx, 1].set_ylabel(f'S (Node {node})\n{labels[node][:15]}...', fontsize=9)
    axes[idx, 1].grid(True, alpha=0.3)
    if idx == 0:
        axes[idx, 1].set_title('Network Dynamics Framework', fontweight='bold', fontsize=11)
    if idx == len(nodes_to_plot) - 1:
        axes[idx, 1].set_xlabel('Time [ms]')

plt.tight_layout()
plt.show()

The time series show qualitatively similar dynamics: both frameworks produce realistic resting-state activity with fluctuations around the fixed point, driven by noise and network interactions. The remaining differences stem from the different noise realizations generated by NumPy (TVB) and JAX (Network Dynamics).

Important: Noise Parameter Difference

TVB and Network Dynamics use different noise parameterizations:

  • TVB: The nsig parameter in integrators.noise.Additive() is set as nsig = 0.5 * sigma^2, where sigma is the desired standard deviation. This is not the variance (which would be sigma^2), but a scaled version specific to TVB’s implementation.

  • Network Dynamics: The sigma parameter in AdditiveNoise() directly represents the standard deviation of the additive Gaussian noise.

To match noise levels between frameworks, use the conversion: nsig_tvb = 0.5 * sigma_network^2
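
A minimal sketch of the conversion in both directions:

sigma = 0.01                        # desired noise std (Network Dynamics)
nsig_tvb = 0.5 * sigma**2           # equivalent TVB nsig parameter
sigma_back = (2 * nsig_tvb) ** 0.5  # recover sigma from nsig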

Component Breakdown

The Network Dynamics framework consists of several modular components that work together. Each component implements a specific part of the mathematical structure described above.

Dynamics

Implements: \(f_d(S_i, \theta^d, C_i, I_i)\) - the local temporal evolution function

The dynamics define the differential equations at each network node: these are your neural mass models, oscillators, or other dynamical systems.

dynamics = ReducedWongWang(w=0.7)  # w is a parameter in θ^d

All TVB models are available out of the box and validated for Network Dynamics.

→ See Dynamics for details on models and creating custom dynamics

Graph (Connectivity)

Implements: \(A_{ij}\) and \(\tau_{ij}\) - the structural connectivity and transmission delays

The graph encodes the structural connectivity between nodes: which regions connect to which, with what strength, and with what delays.

graph = DenseDelayGraph(weights, delays, region_labels=labels)  # delays = lengths / speed

Supports both dense and sparse representations, with optional transmission delays computed from tract lengths.

→ See Graph for connectivity representations and empirical datasets

Coupling

Implements: \(f_c^{\text{pre}}\), \(f_c^{\text{post}}\), and the summation \(\sum_j A_{ij}\, f_c^{\text{pre}}(\cdots)\)

The coupling defines how nodes interact through the connectivity: it transforms states from connected nodes into input to the local dynamics.

coupling = DelayedLinearCoupling(incoming_states='S', G=0.5)  # G is part of θ^c

The coupling pattern mirrors TVB’s pre-sum-post architecture. You can use instantaneous or delayed coupling, linear or nonlinear transformations.
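
To make the pre-sum-post pattern concrete, here is an illustrative sketch of the computation a delayed linear coupling performs (plain JAX math, not the library's internal code):

import jax.numpy as jnp

def linear_coupling(S_delayed, A, G):
    # S_delayed[i, j] = S_j(t - tau_ij): delayed source state per connection
    # pre:  identity transformation of the delayed states
    # sum:  weighted aggregation over incoming connections j
    aggregated = jnp.einsum("ij,ij->i", A, S_delayed)
    # post: apply the global gain G (part of theta^c)
    return G * aggregated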

→ See Coupling for coupling types, delays, and custom implementations

Noise (Optional)

Implements: \(g(S_i, \theta^g)\) - the diffusion coefficient controlling noise intensity

The noise component adds stochastic fluctuations to the dynamics, transforming ODEs into SDEs.

noise = AdditiveNoise(sigma=0.01, key=jax.random.key(42))  # sigma is θ^g

Supports additive and multiplicative noise, with selective application to specific state variables. The Wiener process \(dW_i\) is automatically handled by the solver.
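
As an illustration of how \(g\) and \(dW_i\) enter the integration, a schematic Euler-Maruyama step for additive noise (not the library's solver code) looks like this:

import jax
import jax.numpy as jnp

def euler_maruyama_step(S, drift, sigma, dt, key):
    # dS = drift(S) dt + sigma dW, with dW ~ Normal(0, dt)
    dW = jnp.sqrt(dt) * jax.random.normal(key, S.shape)
    return S + drift(S) * dt + sigma * dW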

→ See Noise for stochastic processes and noise types

External Inputs (Optional)

Implements: \(I_i\) - external driving signals from outside the network

The external input system provides time-dependent (or state-dependent) driving signals to the network nodes.

external_input = DataInput(times, data, interpolation='cubic')

Use parametric inputs (sine waves, pulses) or data-based inputs (interpolated recordings).
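
For example, a data-based 10 Hz sinusoidal drive could be built like this (the times/data construction and the (time, nodes) data layout are illustrative assumptions; the DataInput signature follows the snippet above):

import jax.numpy as jnp

times = jnp.arange(0.0, 1000.0, 1.0)                # sample times [ms]
signal = 0.05 * jnp.sin(2 * jnp.pi * 0.01 * times)  # 10 Hz = 0.01 cycles/ms
data = jnp.tile(signal[:, None], (1, 84))           # same drive for all 84 nodes
external_input = DataInput(times, data, interpolation='cubic')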

→ See External Inputs for stimulation and driving signals

Solvers

Implements: Numerical integration of \(dS_i = [\cdots]dt + g\,dW_i\)

The solver performs numerical integration of the network dynamics using various methods (Euler, Heun, Runge-Kutta).

solver = Heun()
result = solve(network, solver, t0=0.0, t1=1000.0, dt=0.5)

Native solvers (Euler, Heun, RK4) are optimized for brain networks and support all features. Diffrax solvers provide advanced methods for special cases.

→ See Solvers for integration methods and performance considerations

Design Philosophy

Network Dynamics is built on a few core principles that shape its architecture and user experience:

Pure JAX Foundation

Everything is JAX from the ground up:

  • Differentiable: Automatic gradients through entire simulations enable gradient-based parameter optimization
  • JIT-compiled: Automatic compilation to optimized machine code (CPU/GPU/TPU)
  • Vectorizable: vmap for efficient batch simulations across parameter sets or initial conditions
  • GPU-ready: Seamless acceleration on GPUs without code changes

Explicit Over Implicit

Every operation is visible and controllable:

  • No hidden state: All parameters, initial conditions, and configurations are explicitly provided
  • Clear data flow: Function signatures show exactly what goes in and what comes out
  • Transparent components: Each component (dynamics, coupling, noise) has a simple, documented interface
  • No magic: What you write is what executes, no automatic configuration or hidden transformations

Composability

Mix and match components freely:

  • Modular design: Swap dynamics models, coupling functions, or solvers without changing other parts, similar to TVB
  • Custom components: Subclass base classes to implement novel dynamics, coupling, or noise
  • Multiple couplings: Combine instantaneous and delayed, linear and nonlinear coupling in a single network
  • Flexible noise: Apply noise selectively to specific state variables

Optimization-First Design

Built for parameter inference and model fitting through a pure functional architecture:

The prepare() pattern: Under the hood, solve() calls prepare() which separates state from execution:

# High-level: convenient class-based API
result = solve(network, solver, t0=0.0, t1=1000.0, dt=0.5)

# Under the hood: prepare() returns a pure function and state
solve_fn, state = prepare(network, solver, t0=0.0, t1=1000.0, dt=0.5)
result = solve_fn(state)  # Pure function call with PyTree state

Why this enables optimization:

  • The Network, Dynamics, and Coupling classes are convenient builders for complex parameter structures
  • prepare() compiles everything into a pure function and a PyTree of parameters
  • This pure function can be passed to JAX transformations: jit, grad, vmap
  • Direct gradient access: Use JAX’s grad() or value_and_grad() on solve_fn
  • The state is just nested dictionaries (Bunch objects), fully compatible with JAX’s PyTree system
  • Mark parameters for optimization with Parameter and BoundedParameter
  • Easy integration with Optax, BlackJAX, and other JAX optimization libraries
  • Efficiently explore parameter spaces with vmap for batch simulations
  • Inspect, modify, or serialize the state as plain data

This architecture gives you the convenience of high-level classes while maintaining JAX’s functional core for optimization.
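
A minimal sketch of direct gradient access (the loss and `target` are illustrative, and it assumes the relevant parameters were marked with Parameter so they appear as leaves of the state PyTree):

import jax
import jax.numpy as jnp

solve_fn, state = prepare(network, solver, t0=0.0, t1=1000.0, dt=0.5)

def loss(state):
    result = solve_fn(state)                    # pure function of the PyTree state
    return jnp.mean((result.ys - target) ** 2)  # `target` is assumed empirical data

grads = jax.grad(loss)(state)  # gradients w.r.t. every leaf of the state PyTree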

Key Differences from TVB

If you’re coming from TVB, here are the main architectural differences:

Observations & Monitoring

TVB approach: Monitors sample the simulation during integration, selecting variables and applying transformations (downsampling, BOLD, etc.) in real-time.

Network Dynamics approach: Simulations return the full time series, and observations are applied as post-processing. TVB-compatible observation functions are available in tvboptim.observations.tvb_monitors:

from tvboptim.observations.tvb_monitors import Bold, SubSampling

# Run simulation - get full time series
result = solve(network, solver, t0=0.0, t1=10000.0, dt=1.0)

# Create BOLD monitor with standard parameters
bold_monitor = Bold(
    period=720.0,           # BOLD sampling period (1 TR = 720 ms)
    downsample_period=4.0,  # Intermediate downsampling period [ms]
    voi=0,                  # Monitor first state variable (S)
)

# Create subsampling monitor with standard parameters
downsampling_monitor = SubSampling(
    period=10.0,  # Subsample period (10 ms)
)

# Apply observations as post-processing
bold_signal = bold_monitor(result)
downsampled = downsampling_monitor(result)

This makes the pipeline more transparent: you can inspect the raw simulation output, apply different observations to the same data, and compose observation functions.

No Modes Dimension

TVB: Output has shape [time, state, nodes, modes] where modes represent different oscillation modes or population types.

Network Dynamics: Output has shape [time, state, nodes]. If your model has multiple populations (e.g., excitatory/inhibitory), these are represented as separate state variables.

This simplifies indexing and makes it clearer what each dimension represents.
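
For example, extracting a single node's trace differs only by the trailing modes index:

tvb_trace = data_tvb[:, 0, 40, 0]  # TVB: [time, state, node, mode]
nd_trace = result.ys[:, 0, 40]     # Network Dynamics: [time, state, node]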

Noise Parameterization

As noted earlier, noise parameters differ:

  • TVB: nsig = 0.5 * sigma^2 (scaled parameter specific to TVB’s implementation)
  • Network Dynamics: sigma directly represents the standard deviation of the Gaussian noise; equivalently, \(\sigma = \sqrt{2\,\mathrm{nsig}}\)

The Network Dynamics parameterization is more intuitive and matches standard SDE literature.

Next Steps

Start Exploring Components

Depending on your use case, dive into the detailed documentation:

  1. Want to understand dynamics models? → Start with Dynamics to see available models and create custom ones

  2. Working with custom connectivity? → Check Graph for dense/sparse representations and empirical datasets

  3. Need specialized coupling? → See Coupling for coupling types, delays, and performance considerations

  4. Adding stochasticity? → Explore Noise for additive and multiplicative processes

  5. Stimulating the network? → Read External Inputs for parametric and data-based stimulation

  6. Optimizing performance? → Study Solvers for native vs Diffrax methods

Complete Optimization Workflows

Ready to see the full power of Network Dynamics? These end-to-end tutorials demonstrate complete parameter optimization workflows, from setting up networks to fitting model parameters to empirical data using gradient-based optimization:

Reduced Wong-Wang FC Optimization (RWW)

Learn how to fit a whole-brain resting-state network model to empirical fMRI functional connectivity. This tutorial covers:

  • Setting up the RWW dynamics with structural connectivity
  • Defining loss functions for FC fitting
  • Exploring parameter sensitivity

Jansen-Rit MEG Peak Frequency Gradient Optimization (JR)

Reproduce the spatial frequency gradient observed in resting-state MEG data, where peak frequencies vary from ~7 Hz in association areas to ~11 Hz in sensory cortex. This tutorial demonstrates:

  • Modeling cortical columns with the Jansen-Rit neural mass model
  • Parameter exploration using grid search to map the frequency landscape
  • Defining region-specific optimization targets from neuroimaging data
  • Fitting heterogeneous (region-specific) parameters with gradient-based optimization
  • Spectral analysis and validation against empirical patterns

Get Help

  • Check the API Reference for detailed parameter descriptions
  • Look at examples in the repository
  • Report issues or ask questions on GitHub