prepare
`experimental.network_dynamics.prepare(network, solver, t0=0.0, t1=1.0, dt=0.1)`

Prepare a network dynamics model for simulation.
Transforms a network dynamics model into a JAX-compiled simulation function and a corresponding configuration object. Both native solvers (Euler, Heun) and Diffrax solvers are supported; the two families differ in feature set and performance characteristics.
The preparation process optimizes the model for efficient execution by pre-compiling closures, pre-allocating buffers, and structuring data for JAX transformations.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| network | Network | Network dynamics model containing:<br>- dynamics : neural mass/population model (e.g., ReducedWongWang, JansenRit)<br>- couplings : inter-region coupling functions (delayed or instantaneous)<br>- graph : connectivity structure (weights, delays, distances)<br>- noise : optional stochastic process (additive/multiplicative)<br>- externals : optional external inputs (e.g., stimulation) | required |
| solver | NativeSolver or DiffraxSolver | Integration method; two solver families are available.<br>NativeSolver (Euler, Heun):<br>- fixed time-step integration<br>- supports all features: delays, noise, stateful operations<br>- optimized for jax.lax.scan<br>- best for most brain network simulations<br>DiffraxSolver (Tsit5, Dopri5, etc.):<br>- adaptive time stepping<br>- stateless only: no delayed coupling, no history buffers<br>- useful for stiff ODEs or when adaptive stepping is required<br>- raises ValueError if the network has delays | required |
| t0 | float | Simulation start time, by default 0.0 | 0.0 |
| t1 | float | Simulation end time, by default 1.0 | 1.0 |
| dt | float | Integration time step, by default 0.1.<br>- NativeSolver: fixed step size used throughout the simulation<br>- DiffraxSolver: initial step size (dt0) for the adaptive step-size controller | 0.1 |
Returns
| Name | Type | Description |
|---|---|---|
| solve_function | Callable | Pure JAX function for running the simulation. Signature: solve_function(config) -> results. The function is JIT-compiled and supports:<br>- automatic differentiation (jax.grad, jax.jacobian)<br>- vectorization (jax.vmap)<br>- parallel execution (jax.pmap) |
| config | Bunch | Configuration PyTree containing:<br>- dynamics : dynamics model parameters<br>- coupling : coupling parameters (one entry per coupling)<br>- external : external input parameters (one entry per input)<br>- noise : noise parameters (if stochastic)<br>- graph : graph structure (weights, delays)<br>- initial_state : initial conditions [n_states, n_nodes]<br>- _internal : precomputed data (coupling indices, noise samples, etc.) |
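The key property of the returned pair is that solve_function is a pure function of the config PyTree: every tunable parameter lives in config, so the same compiled function can be re-run, differentiated, or mapped over batches of parameters. The sketch below illustrates that pattern in plain Python; the names (prepare_sketch, decay) are hypothetical stand-ins, not the actual tvboptim API.

```python
# Illustrative sketch: a "prepared" simulation is a pure function of its config.
# All names and structure here are hypothetical, not tvboptim internals.

def prepare_sketch(decay, n_steps):
    """Build a pure solve function plus a config dict (stand-in for the PyTree)."""
    config = {"dynamics": {"decay": decay}, "initial_state": 1.0}

    def solve_function(cfg):
        # Pure: the output depends only on cfg, which is what lets
        # jit/grad/vmap compose with the real prepared function.
        x = cfg["initial_state"]
        for _ in range(n_steps):
            x = x * (1.0 - cfg["dynamics"]["decay"])
        return x

    return solve_function, config

solve_fn, config = prepare_sketch(decay=0.1, n_steps=3)
baseline = solve_fn(config)  # ~0.729

# "vmap-like" sweep: map the same pure function over modified configs.
sweep = [solve_fn({**config, "dynamics": {"decay": d}}) for d in (0.0, 0.1, 0.5)]
# sweep[0] == 1.0, sweep[2] == 0.125
```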
Raises
| Name | Type | Description |
|---|---|---|
| ValueError | If using DiffraxSolver with delayed coupling (network.max_delay > 0). Diffrax solvers cannot maintain history buffers due to internal loop control. |
Examples
Basic Usage with Native Solver
>>> from tvboptim.experimental.network_dynamics import Network, prepare
>>> from tvboptim.experimental.network_dynamics.dynamics import ReducedWongWang
>>> from tvboptim.experimental.network_dynamics.solvers import Euler
>>> from tvboptim.experimental.network_dynamics.coupling import LinearCoupling
>>> from tvboptim.experimental.network_dynamics.graph import DenseGraph
>>> import jax.numpy as jnp
>>>
>>> # Create network components
>>> dynamics = ReducedWongWang()
>>> coupling = LinearCoupling(incoming_states='S', G=1.0)
>>> weights = jnp.ones((68, 68)) # 68 brain regions
>>> graph = DenseGraph(weights)
>>>
>>> # Build network
>>> network = Network(dynamics, coupling, graph)
>>>
>>> # Prepare for simulation
>>> model_fn, config = prepare(network, Euler(), t0=0, t1=100, dt=0.1)
>>>
>>> # Run simulation
>>> results = model_fn(config)
>>> print(results.data.shape) # [n_timesteps, n_voi, n_nodes]

With Delayed Coupling (Native Solver Only)
>>> from tvboptim.experimental.network_dynamics.coupling import DelayedLinearCoupling
>>> from tvboptim.experimental.network_dynamics.graph import DenseDelayGraph
>>>
>>> # Create graph with heterogeneous delays
>>> delays = jnp.array([...]) # [n_nodes, n_nodes] delay matrix in ms
>>> graph = DenseDelayGraph(weights, delays)
>>>
>>> # Delayed coupling requires history buffer
>>> coupling = DelayedLinearCoupling(incoming_states='S', G=2.0)
>>> network = Network(dynamics, coupling, graph)
>>>
>>> # Only NativeSolver supports delays
>>> model_fn, config = prepare(network, Euler(), t0=0, t1=100, dt=0.1)

With Adaptive Stepping (Diffrax Solver)
>>> from tvboptim.experimental.network_dynamics.solvers import DiffraxSolver
>>> import diffrax
>>>
>>> # Diffrax solver with adaptive time stepping
>>> solver = DiffraxSolver(
... diffrax.Tsit5(),
... saveat=diffrax.SaveAt(ts=jnp.arange(0, 100, 0.1))
... )
>>>
>>> # Network must NOT have delays for Diffrax
>>> network = Network(dynamics, LinearCoupling(...), graph)
>>> model_fn, config = prepare(network, solver, t0=0, t1=100, dt=0.1)
>>> solution = model_fn(config) # Returns diffrax.Solution object

With Stochastic Dynamics
>>> from tvboptim.experimental.network_dynamics.noise import AdditiveNoise
>>> import jax
>>>
>>> # Add noise to network
>>> noise = AdditiveNoise(state_indices=[0], sigma=0.01, key=jax.random.PRNGKey(0))
>>> network = Network(dynamics, coupling, graph, noise=noise)
>>>
>>> # Prepare with noise (pre-generates noise samples)
>>> model_fn, config = prepare(network, Euler(), t0=0, t1=100, dt=0.1)

Modifying Parameters
>>> # Config is a PyTree - parameters can be modified
>>> import copy
>>> config_modified = copy.deepcopy(config)
>>> config_modified.dynamics.G = 2.5 # Change global coupling
>>> config_modified.coupling.default.G = 1.5 # Change coupling strength
>>>
>>> # Run with modified parameters
>>> results_modified = model_fn(config_modified)

Notes
Preparation Steps (NativeSolver):
- Prepare all couplings (create history buffers for delays if needed)
- Build config structure with flattened parameters and graph
- Pre-generate noise samples if stochastic (one sample per timestep)
- Pre-compile coupling computation closures (avoid dict lookups in scan)
- Pre-compile state update closures (for history buffer management)
- Return pure function optimized for jax.lax.scan
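The closure pre-compilation step above can be illustrated in plain Python: coupling functions are resolved from a registry once, at prepare time, so the scan body makes direct calls with no per-step dict access. The registry, spec format, and function names below are illustrative only, not the tvboptim internals.

```python
# Sketch of closure pre-compilation (hypothetical names, not tvboptim code).
COUPLING_REGISTRY = {
    "linear": lambda x, g: g * x,
}

def prepare_couplings(specs):
    """Resolve (name, gain) specs into a fixed tuple of closures."""
    resolved = []
    for name, gain in specs:
        fn = COUPLING_REGISTRY[name]  # dict lookup happens here, once
        # Bind fn and gain via default args so each closure is self-contained.
        resolved.append(lambda x, fn=fn, gain=gain: fn(x, gain))
    return tuple(resolved)

def step(x, couplings):
    # Step body (what jax.lax.scan would trace): plain calls, no lookups.
    return x + sum(c(x) for c in couplings)

couplings = prepare_couplings([("linear", 0.5)])
x = 1.0
for _ in range(2):
    x = step(x, couplings)
print(x)  # 2.25
```

Resolving lookups before the loop matters under jax.lax.scan in particular, since the traced step body should contain only array operations and direct calls.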
Preparation Steps (DiffraxSolver):
- Validate network has no delays (raises ValueError if found)
- Prepare stateless coupling/external input data
- Build config with parameters and precomputed data
- Create Diffrax vector field and control term (for SDEs)
- Return pure function wrapping diffrax.diffeqsolve
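The validation step can be pictured as a simple guard: if the network carries any delay, preparation fails before a Diffrax solve is ever built. The function and message below are a hypothetical sketch of that check, not the actual tvboptim implementation.

```python
# Sketch of the delay validation for Diffrax solvers (hypothetical names).
def validate_for_diffrax(max_delay):
    """Diffrax controls its own loop, so history buffers (delays) are unsupported."""
    if max_delay > 0:
        raise ValueError(
            "DiffraxSolver does not support delayed coupling "
            f"(network.max_delay = {max_delay} > 0); use a NativeSolver."
        )

validate_for_diffrax(0.0)   # no delays: passes silently
try:
    validate_for_diffrax(5.0)
except ValueError as e:
    print("raised:", e)
```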
Solver Selection Guidelines:
Use NativeSolver (Euler, Heun) when:
- Network has delayed coupling
- Need full control over integration loop
- Want optimal performance with jax.lax.scan
- Standard brain network simulation
Use DiffraxSolver when:
- Network has no delays (stateless)
- Need adaptive time stepping for stiff systems
- Want access to advanced Diffrax features
- Require error control and step size adaptation
Performance Notes:
- Native solvers use jax.lax.scan for optimal compile-time optimization
- Pre-compilation of closures eliminates runtime overhead
- History buffers for delays use efficient circular indexing
- Noise samples are pre-generated to avoid per-step RNG calls
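Two of these tricks can be sketched together in plain Python: a circular history buffer indexed modulo its length for delayed coupling, and a noise array drawn up front (one sample per step) so the loop body makes no RNG calls. Buffer sizes, the toy dynamics, and all names are illustrative, not tvboptim internals.

```python
# Sketch of circular delay indexing + pre-generated noise (hypothetical code).
import random

dt = 0.1
delay_steps = 3                      # delay expressed in integration steps
buf_len = delay_steps + 1            # history buffer length
history = [0.0] * buf_len            # circular buffer, zero-initialized history

random.seed(0)
n_steps = 10
noise = [random.gauss(0.0, 0.01) for _ in range(n_steps)]  # one sample per step

x = 1.0
for t in range(n_steps):
    history[t % buf_len] = x                          # write current state
    x_delayed = history[(t - delay_steps) % buf_len]  # read state from t - delay
    x = x + dt * (-x + 0.5 * x_delayed) + noise[t]    # Euler step: decay + delayed input
print(x)
```

The modulo indexing means the buffer is overwritten in place rather than shifted, so each step costs one write and one read regardless of the delay length.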
See Also
- solve : High-level interface that calls prepare() and executes immediately
- Network : Network dynamics model container
- NativeSolver : Fixed-step integration methods (Euler, Heun)
- DiffraxSolver : Adaptive-step integration using the Diffrax library