---
title: "Solvers"
format:
  html:
    code-fold: false
    fig-width: 8
    out-width: "100%"
jupyter: python3
execute:
  cache: true
---
## The *solve()* Function
The `solve()` function is the main entry point for running network simulations. It takes a network configuration and a solver, then returns the time-evolved state.
**Basic signature:**
```python
from tvboptim.experimental.network_dynamics import solve
result = solve(
    network,    # Network configuration (dynamics, coupling, graph)
    solver,     # Solver instance (Native or Diffrax)
    t0=0.0,     # Start time [ms]
    t1=1000.0,  # End time [ms]
    dt=0.1      # Integration time step [ms]
)
```
**Key parameters:**
- `network`: A `Network` object containing dynamics, coupling, graph, and optionally noise
- `solver`: A solver instance that determines the integration method
- `t0`, `t1`: Time interval to simulate (in milliseconds)
- `dt`: Integration time step (in milliseconds)
**Return value:**
The `solve()` function returns a solution object with two main attributes:
- `result.ts`: Time points array `[n_time]`
- `result.ys`: State trajectories `[n_time, n_states, n_nodes]`
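For example, you can slice `result.ys` along the state axis to pull out individual trajectories (a minimal sketch based on the shapes above):
```python
# All nodes' trajectories for the first saved variable: [n_time, n_nodes]
first_state = result.ys[:, 0, :]
# Time series of the first saved variable at node 0: [n_time]
node0_trace = result.ys[:, 0, 0]
```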
---
## Native Solvers (Recommended)
Native solvers are JAX-based implementations optimized for brain network simulations. **These are the go-to solvers** for most use cases.
### Features
- Auxiliary variables (intermediate computations)
- Variables of Interest (VOI) for selective output
- Stateful couplings (delayed connections, history buffers)
- Both deterministic (ODE) and stochastic (SDE) integration
- Optimized for large-scale brain networks
### Available Methods
```python
from tvboptim.experimental.network_dynamics.solvers import Euler, Heun, RungeKutta4
# Euler method (1st order)
solver = Euler()
# Heun method (2nd order, predictor-corrector)
solver = Heun()
# Runge-Kutta 4th order (higher accuracy)
solver = RungeKutta4()
```
### Basic Example
```{python}
#| eval: true
import jax.numpy as jnp
from tvboptim.experimental.network_dynamics import Network, solve
from tvboptim.experimental.network_dynamics.dynamics.tvb import ReducedWongWang
from tvboptim.experimental.network_dynamics.coupling import LinearCoupling
from tvboptim.experimental.network_dynamics.graph import DenseGraph
from tvboptim.experimental.network_dynamics.solvers import Heun
# Create network
weights = jnp.array([[0.0, 1.0], [1.0, 0.0]])
graph = DenseGraph(weights)
dynamics = ReducedWongWang()
coupling = LinearCoupling(incoming_states="S", G=2.0)
network = Network(
    dynamics=dynamics,
    coupling={'instant': coupling},
    graph=graph
)
# Use native Heun solver
solver = Heun()
# Run simulation
result = solve(network, solver, t0=0.0, t1=100.0, dt=0.1)
print(f"Result shape: {result.ys.shape}") # [n_time, n_states, n_nodes]
print(f"Time range: [{result.ts[0]:.1f}, {result.ts[-1]:.1f}] ms")
```
### Variables of Interest and Auxiliary Variables
You can control which variables to save in the output using `VARIABLES_OF_INTEREST`. This is useful for:
1. **Reducing memory usage** - Save only the states you need
2. **Including auxiliary variables** - Some dynamics models compute intermediate values (auxiliary variables). These are **only saved if explicitly included in `VARIABLES_OF_INTEREST`**.
Let's demonstrate with the ReducedWongWang model, which computes the firing rate 'H' as an auxiliary variable:
```{python}
#| eval: true
# ReducedWongWang has state variable 'S' and auxiliary variable 'H' (firing rate)
dynamics_default = ReducedWongWang()
print(f"State variables: {dynamics_default.STATE_NAMES}")
print(f"Auxiliary variables: {dynamics_default.AUXILIARY_NAMES}")
# By default, only state variables are saved
network_default = Network(
    dynamics=dynamics_default,
    coupling={'instant': coupling},
    graph=graph
)
result_default = solve(network_default, Heun(), t0=0.0, t1=100.0, dt=0.1)
print(f"\nDefault - saved variables: {result_default.ys.shape[1]}") # Only state variables (1)
# Now include the auxiliary variable 'H' (firing rate) in VOI
dynamics_with_H = ReducedWongWang(
    VARIABLES_OF_INTEREST=('S', 'H')  # Include firing rate auxiliary variable
)
network_with_H = Network(
    dynamics=dynamics_with_H,
    coupling={'instant': coupling},
    graph=graph
)
result_with_H = solve(network_with_H, Heun(), t0=0.0, t1=100.0, dt=0.1)
print(f"With auxiliary 'H' - saved variables: {result_with_H.ys.shape[1]}") # State + aux (2)
print(f"\nShape changed from {result_default.ys.shape[1]} to {result_with_H.ys.shape[1]} by including 'H' (firing rate) in VOI")
```
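If the saved variables follow the order of the `VARIABLES_OF_INTEREST` tuple (as the example above suggests), you can look up a variable's index by name. A small sketch, assuming the constructor stores the tuple as an attribute:
```python
# Index of the auxiliary variable 'H' among the saved variables
# (assumes the saved order matches the VARIABLES_OF_INTEREST tuple)
h_index = dynamics_with_H.VARIABLES_OF_INTEREST.index('H')
# Firing-rate trajectories for all nodes: [n_time, n_nodes]
firing_rate = result_with_H.ys[:, h_index, :]
```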
### Stochastic Simulations (SDE)
Native solvers support stochastic differential equations by adding noise:
```{python}
#| eval: true
import jax
from tvboptim.experimental.network_dynamics.noise import AdditiveNoise
# Add Gaussian white noise
noise = AdditiveNoise(sigma=0.01, key=jax.random.key(42))
network_sde = Network(
    dynamics=dynamics,
    coupling={'instant': coupling},
    graph=graph,
    noise=noise  # Enable stochastic integration
)
# Use Heun for SDE (becomes Heun-Maruyama method)
result_sde = solve(network_sde, Heun(), t0=0.0, t1=100.0, dt=0.1)
print(f"SDE result shape: {result_sde.ys.shape}")
```
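Because the noise is seeded with an explicit PRNG key, stochastic runs are reproducible. To sample several independent realizations, build one noise source per key (a minimal sketch using only the pieces defined above):
```python
# Three independent noise realizations from different PRNG keys
results = []
for seed in range(3):
    noise_i = AdditiveNoise(sigma=0.01, key=jax.random.key(seed))
    network_i = Network(
        dynamics=dynamics,
        coupling={'instant': coupling},
        graph=graph,
        noise=noise_i
    )
    results.append(solve(network_i, Heun(), t0=0.0, t1=100.0, dt=0.1))
```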
---
## Diffrax Solvers (Advanced)
For advanced use cases, TVB-Optim supports [Diffrax](https://docs.kidger.site/diffrax/), a JAX library for differential equations that adds implicit methods, adaptive time stepping, and fine-grained step-size control.
### When to Use Diffrax
**Good for:**
- Stiff ODEs requiring implicit methods
- Adaptive time stepping
- Advanced stepsize control
- Specialized solver algorithms
**Limitations:**
- **No stateful couplings** (no delayed connections, no history buffers)
- **No auxiliary variable tracking**
- Solution arrays may be padded with `inf` when using `max_steps`
::: {.callout-important}
Diffrax solvers **do not support delayed coupling** or any coupling that requires `update_state()`. Only instantaneous (stateless) couplings are supported.
:::
### Available Methods
Diffrax provides [many solvers](https://docs.kidger.site/diffrax/api/solvers/ode_solvers/). Common choices:
```python
from tvboptim.experimental.network_dynamics.solvers import DiffraxSolver
import diffrax
# Explicit methods (for non-stiff ODEs)
solver = DiffraxSolver(solver=diffrax.Dopri5()) # Adaptive RK45
solver = DiffraxSolver(solver=diffrax.Euler()) # Euler
solver = DiffraxSolver(solver=diffrax.Heun()) # Heun
# Implicit methods (for stiff ODEs)
solver = DiffraxSolver(solver=diffrax.Kvaerno5()) # Implicit RK
solver = DiffraxSolver(solver=diffrax.ImplicitEuler()) # Backward Euler
# SDE methods
solver = DiffraxSolver(solver=diffrax.EulerHeun()) # SDE-compatible
```
### Basic Example with Implicit Solver
For stiff systems (e.g., fast inhibitory dynamics), implicit methods can be more stable:
```{python}
#| eval: false
import diffrax
from tvboptim.experimental.network_dynamics.solvers import DiffraxSolver
# Create network with stiff dynamics
# (import path for JansenRit assumed to mirror ReducedWongWang above)
from tvboptim.experimental.network_dynamics.dynamics.tvb import JansenRit

dynamics = JansenRit(b=0.5)  # Fast inhibitory time constant
network = Network(
    dynamics=dynamics,
    coupling={'instant': coupling},  # Only instantaneous coupling!
    graph=graph
)

# Use an implicit solver for the stiff ODE
solver = DiffraxSolver(
    solver=diffrax.Kvaerno5(),  # 5th order implicit Runge-Kutta
    stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
    saveat=diffrax.SaveAt(ts=jnp.linspace(0, 1000, 10000))
)
result = solve(network, solver, t0=0.0, t1=1000.0, dt=0.1)
```
### Adaptive Stepping
Diffrax excels at adaptive time stepping for efficient integration:
```{python}
#| eval: false
# Adaptive stepping with error control
solver = DiffraxSolver(
    solver=diffrax.Dopri5(),  # Dormand-Prince 5(4) adaptive method
    stepsize_controller=diffrax.PIDController(
        rtol=1e-3,   # Relative tolerance
        atol=1e-6,   # Absolute tolerance
        dtmin=0.01,  # Minimum step size
        dtmax=1.0    # Maximum step size
    ),
    saveat=diffrax.SaveAt(ts=jnp.linspace(0, 1000, 1000))  # Output points
)
result = solve(network, solver, t0=0.0, t1=1000.0, dt=0.1)
```
### Handling Inf-Padded Arrays
When using `max_steps`, Diffrax may pad solution arrays with `inf` values:
```{python}
#| eval: false
# Approach 1: Use explicit saveat (recommended)
solver = DiffraxSolver(
    solver=diffrax.Euler(),
    saveat=diffrax.SaveAt(ts=jnp.linspace(0, 1000, 10000))  # Explicit times
)
# Approach 2: Filter finite values in post-processing
result = solve(network, solver, t0=0.0, t1=1000.0, dt=0.1)
# Filter out inf-padded entries
finite_mask = jnp.isfinite(result.ts)
ts_clean = result.ts[finite_mask]
ys_clean = result.ys[finite_mask]
```
### SDE with Diffrax
Diffrax supports stochastic integration with Brownian motion:
```{python}
#| eval: false
# Create stochastic network
noise = AdditiveNoise(sigma=0.01, key=jax.random.key(42))
network = Network(
    dynamics=dynamics,
    coupling={'instant': coupling},
    graph=graph,
    noise=noise
)
# Use SDE-compatible solver
solver = DiffraxSolver(
    solver=diffrax.Heun(),  # SDE-compatible (Stratonovich)
    stepsize_controller=diffrax.ConstantStepSize(),
    saveat=diffrax.SaveAt(ts=jnp.linspace(0, 1000, 10000))
)
result = solve(network, solver, t0=0.0, t1=1000.0, dt=0.1)
```
## Solver Comparison & Recommendations
| Feature | Native Solvers | Diffrax Solvers |
|---------|---------------|-----------------|
| **Stateful coupling** (delays) | ✓ Yes | ✗ No |
| **Auxiliary variables** | ✓ Yes | ✗ No |
| **Variables of Interest** | ✓ Yes | ✗ No |
| **ODE support** | ✓ Yes | ✓ Yes |
| **SDE support** | ✓ Yes | ✓ Yes |
| **Implicit methods** | ✗ No | ✓ Yes |
| **Adaptive stepping** | ✗ No | ✓ Yes |
| **Large-scale networks** | ✓ Optimized | ✓ Good |
| **Ease of use** | ✓ Simple | Advanced |
---
**Use Native solvers when:**
- You have delayed couplings or need stateful connections
- You need auxiliary variables for monitoring
- You want a simple, reliable solution
- You're working with standard brain network models
**Use Diffrax solvers when:**
- You have stiff ODEs requiring implicit methods
- You need adaptive time stepping
- You want advanced error control
- Your couplings are instantaneous only
**For most brain network simulations, start with Native solvers (Heun or RK4).**
---
## Further Reading
- [Diffrax Documentation](https://docs.kidger.site/diffrax/)
- [Diffrax Solver API](https://docs.kidger.site/diffrax/api/solvers/ode_solvers/)
- [Diffrax Examples](https://docs.kidger.site/diffrax/examples/basic_ode/)