solve

experimental.network_dynamics.solve

Solve dynamical systems defined on a network architecture.

This module provides the prepare-solve pattern for Network models with multi-coupling support. The prepare() function sets up the integration, including all coupling state management, and returns a pure function for execution.

Functions

Name Description
prepare Prepare a network dynamics model for simulation.
solve Main entry point for network simulation.

prepare

experimental.network_dynamics.solve.prepare(
    network,
    solver,
    t0=0.0,
    t1=1.0,
    dt=0.1,
)

Prepare a network dynamics model for simulation.

Transforms a network dynamics model into a JIT-compiled simulation function and a corresponding configuration object. Supports both native solvers (Euler, Heun) and Diffrax solvers, which differ in feature set and performance characteristics.

The preparation process optimizes the model for efficient execution by pre-compiling closures, pre-allocating buffers, and structuring data for JAX transformations.

Parameters

network : Network, required
    Network dynamics model containing:
      • dynamics : neural mass/population model (e.g., ReducedWongWang, JansenRit)
      • couplings : inter-region coupling functions (delayed or instantaneous)
      • graph : connectivity structure (weights, delays, distances)
      • noise : optional stochastic process (additive/multiplicative)
      • externals : optional external inputs (e.g., stimulation)

solver : NativeSolver or DiffraxSolver, required
    Integration method. Two solver families are available:
    NativeSolver (Euler, Heun):
      • fixed time-step integration
      • supports all features: delays, noise, stateful operations
      • optimized for jax.lax.scan
      • best choice for most brain network simulations
    DiffraxSolver (Tsit5, Dopri5, etc.):
      • adaptive time stepping
      • stateless only: no delayed coupling, no history buffers
      • useful for stiff ODEs or when adaptive stepping is required
      • raises ValueError if the network has delays

t0 : float, default 0.0
    Simulation start time.

t1 : float, default 1.0
    Simulation end time.

dt : float, default 0.1
    Integration time step. For NativeSolver, the fixed step size used
    throughout the simulation; for DiffraxSolver, the initial step size
    (dt0) passed to the adaptive controller.

Returns

solve_function : Callable
    Pure JAX function for running the simulation, with signature
    solve_function(config) -> results. The function is JIT-compiled and supports:
      • automatic differentiation (jax.grad, jax.jacobian)
      • vectorization (jax.vmap)
      • parallel execution (jax.pmap)

config : Bunch
    Configuration PyTree containing:
      • dynamics : dynamics model parameters
      • coupling : coupling parameters (one entry per coupling)
      • external : external input parameters (one entry per input)
      • noise : noise parameters (if stochastic)
      • graph : graph structure (weights, delays)
      • initial_state : initial conditions [n_states, n_nodes]
      • _internal : precomputed data (coupling indices, noise samples, etc.)

Raises

ValueError
    Raised when a DiffraxSolver is used with delayed coupling
    (network.max_delay > 0). Diffrax solvers cannot maintain history
    buffers due to internal loop control.

Examples

Basic Usage with Native Solver

>>> from tvboptim.experimental.network_dynamics import Network, prepare
>>> from tvboptim.experimental.network_dynamics.dynamics import ReducedWongWang
>>> from tvboptim.experimental.network_dynamics.solvers import Euler
>>> from tvboptim.experimental.network_dynamics.coupling import LinearCoupling
>>> from tvboptim.experimental.network_dynamics.graph import DenseGraph
>>> import jax.numpy as jnp
>>>
>>> # Create network components
>>> dynamics = ReducedWongWang()
>>> coupling = LinearCoupling(incoming_states='S', G=1.0)
>>> weights = jnp.ones((68, 68))  # 68 brain regions
>>> graph = DenseGraph(weights)
>>>
>>> # Build network
>>> network = Network(dynamics, coupling, graph)
>>>
>>> # Prepare for simulation
>>> model_fn, config = prepare(network, Euler(), t0=0, t1=100, dt=0.1)
>>>
>>> # Run simulation
>>> results = model_fn(config)
>>> print(results.data.shape)  # [n_timesteps, n_voi, n_nodes]

With Delayed Coupling (Native Solver Only)

>>> from tvboptim.experimental.network_dynamics.coupling import DelayedLinearCoupling
>>> from tvboptim.experimental.network_dynamics.graph import DenseDelayGraph
>>>
>>> # Create graph with heterogeneous delays
>>> delays = jnp.array([...])  # [n_nodes, n_nodes] delay matrix in ms
>>> graph = DenseDelayGraph(weights, delays)
>>>
>>> # Delayed coupling requires history buffer
>>> coupling = DelayedLinearCoupling(incoming_states='S', G=2.0)
>>> network = Network(dynamics, coupling, graph)
>>>
>>> # Only NativeSolver supports delays
>>> model_fn, config = prepare(network, Euler(), t0=0, t1=100, dt=0.1)

With Adaptive Stepping (Diffrax Solver)

>>> from tvboptim.experimental.network_dynamics.solvers import DiffraxSolver
>>> import diffrax
>>>
>>> # Diffrax solver with adaptive time stepping
>>> solver = DiffraxSolver(
...     diffrax.Tsit5(),
...     saveat=diffrax.SaveAt(ts=jnp.arange(0, 100, 0.1))
... )
>>>
>>> # Network must NOT have delays for Diffrax
>>> network = Network(dynamics, LinearCoupling(...), graph)
>>> model_fn, config = prepare(network, solver, t0=0, t1=100, dt=0.1)
>>> solution = model_fn(config)  # Returns diffrax.Solution object

With Stochastic Dynamics

>>> from tvboptim.experimental.network_dynamics.noise import AdditiveNoise
>>> import jax
>>>
>>> # Add noise to network
>>> noise = AdditiveNoise(state_indices=[0], sigma=0.01, key=jax.random.PRNGKey(0))
>>> network = Network(dynamics, coupling, graph, noise=noise)
>>>
>>> # Prepare with noise (pre-generates noise samples)
>>> model_fn, config = prepare(network, Euler(), t0=0, t1=100, dt=0.1)

Modifying Parameters

>>> # Config is a PyTree - parameters can be modified
>>> import copy
>>> config_modified = copy.deepcopy(config)
>>> config_modified.dynamics.G = 2.5  # Change global coupling
>>> config_modified.coupling.default.G = 1.5  # Change coupling strength
>>>
>>> # Run with modified parameters
>>> results_modified = model_fn(config_modified)
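Because the returned function is pure and the config is a PyTree, modified parameters like these can also be differentiated through. Below is a minimal sketch using a toy scan-based simulator as a stand-in for the prepared solve function (the names `simulate` and the `"G"` entry are illustrative, not the tvboptim API); it shows how jax.grad flows back to a config-style PyTree:

```python
import jax
import jax.numpy as jnp

# Toy stand-in for the prepared solve function: pure, scan-based, and
# parameterized by a config-style PyTree (here a plain dict).
def simulate(config):
    g = config["G"]  # scalar parameter, analogous to config.dynamics.G
    def step(state, _):
        new_state = state + 0.1 * (-state + g)  # toy Euler update
        return new_state, new_state
    _, traj = jax.lax.scan(step, jnp.array(0.0), None, length=50)
    return traj[-1]  # final state, used here as a scalar loss

# Gradients flow through the scan and back to the config PyTree.
g_grad = jax.grad(simulate)({"G": jnp.array(1.0)})["G"]
```

The same pattern applies to the real config: wrap the solve function in a scalar loss and call jax.grad on it with respect to the config.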

Notes

Preparation Steps (NativeSolver):

  1. Prepare all couplings (create history buffers for delays if needed)
  2. Build config structure with flattened parameters and graph
  3. Pre-generate noise samples if stochastic (one sample per timestep)
  4. Pre-compile coupling computation closures (avoid dict lookups in scan)
  5. Pre-compile state update closures (for history buffer management)
  6. Return pure function optimized for jax.lax.scan
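The scan-oriented steps above can be sketched with a toy example. The names here (`make_euler_scan`, `coupling`, `step`) are illustrative, not part of the tvboptim API; the point is the closure pre-compilation of step 4 and the pure scan body of step 6:

```python
import jax
import jax.numpy as jnp

def make_euler_scan(weights, dt, a=1.0):
    # Step 4: compile the coupling computation into a closure so the
    # scan body performs no dict lookups at runtime.
    def coupling(state):
        return weights @ state  # linear coupling across nodes

    # Pure step function: one Euler update per scan iteration.
    def step(state, _):
        dstate = -a * state + coupling(state)
        new_state = state + dt * dstate
        return new_state, new_state  # (carry, per-step output)

    # Step 6: pure function driven by jax.lax.scan.
    def solve_fn(initial_state, n_steps):
        _, trajectory = jax.lax.scan(step, initial_state, None, length=n_steps)
        return trajectory  # shape [n_steps, n_nodes]

    return jax.jit(solve_fn, static_argnums=1)

weights = jnp.zeros((4, 4))  # uncoupled toy network for illustration
solve_fn = make_euler_scan(weights, dt=0.1)
traj = solve_fn(jnp.ones(4), 10)
```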

Preparation Steps (DiffraxSolver):

  1. Validate network has no delays (raises ValueError if found)
  2. Prepare stateless coupling/external input data
  3. Build config with parameters and precomputed data
  4. Create Diffrax vector field and control term (for SDEs)
  5. Return pure function wrapping diffrax.diffeqsolve
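Step 1 amounts to an early delay check. A minimal sketch of such validation, with a hypothetical helper name:

```python
# Hypothetical sketch of step 1: the Diffrax path rejects delayed networks
# up front, because adaptive solvers cannot maintain history buffers.
def validate_stateless(max_delay):
    """Raise early if the network carries delayed coupling."""
    if max_delay > 0:
        raise ValueError(
            f"DiffraxSolver does not support delayed coupling "
            f"(network.max_delay = {max_delay}); use a NativeSolver instead."
        )

validate_stateless(0.0)  # instantaneous coupling passes silently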

Solver Selection Guidelines:

Use NativeSolver (Euler, Heun) when:

  • Network has delayed coupling
  • Need full control over integration loop
  • Want optimal performance with jax.lax.scan
  • Standard brain network simulation

Use DiffraxSolver when:

  • Network has no delays (stateless)
  • Need adaptive time stepping for stiff systems
  • Want access to advanced Diffrax features
  • Require error control and step size adaptation

Performance Notes:

  • Native solvers use jax.lax.scan for optimal compile-time optimization
  • Pre-compilation of closures eliminates runtime overhead
  • History buffers for delays use efficient circular indexing
  • Noise samples are pre-generated to avoid per-step RNG calls
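The circular-indexing idea behind delay history buffers can be sketched as follows; the buffer layout and helper names are illustrative, not tvboptim internals:

```python
# Illustrative circular history buffer: values are written at t % L and a
# value delayed by d steps is read from (t - d) % L; nothing is shifted.
L = 5                      # buffer length in steps (must cover the max delay)
buffer = [0.0] * L

def write(t, value):
    buffer[t % L] = value

def read_delayed(t, d):
    return buffer[(t - d) % L]

for t in range(8):         # store the step index itself as the "state"
    write(t, float(t))

# At t = 8, a delay of 3 steps retrieves the value written at t = 5.
delayed = read_delayed(8, 3)
```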

See Also

solve : High-level interface that calls prepare() and executes immediately
Network : Network dynamics model container
NativeSolver : Fixed-step integration methods (Euler, Heun)
DiffraxSolver : Adaptive-step integration using the Diffrax library

solve

experimental.network_dynamics.solve.solve(
    network,
    solver,
    t0=0.0,
    t1=100.0,
    dt=0.1,
)

Main entry point for network simulation.

Parameters

network : Network, required
    Network instance with multi-coupling support.

solver : NativeSolver, required
    NativeSolver instance (Euler, Heun, etc.).

t0 : float, default 0.0
    Start time.

t1 : float, default 100.0
    End time.

dt : float, default 0.1
    Time step.

Returns

results
    Simulation results wrapped in a result object.

Examples

>>> from network_dynamics import Network, solve
>>> from network_dynamics.solvers import Euler
>>> from network_dynamics.dynamics import Lorenz
>>> from network_dynamics.coupling import LinearCoupling
>>> from network_dynamics.graph import Graph
>>>
>>> dynamics = Lorenz()
>>> coupling = LinearCoupling(incoming_states='x', G=1.0)
>>> graph = Graph(weights)
>>> network = Network(dynamics, coupling, graph)
>>>
>>> result = solve(network, Euler(), t0=0, t1=10, dt=0.01)