solve

experimental.network_dynamics.solve

Solving system for the network architecture.

This module provides the prepare-solve pattern for Network, with support for multiple couplings. The prepare() function sets up the integration, including all coupling-state management, and returns a pure solve function together with its configuration PyTree.
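
The prepare-solve split can be illustrated with a self-contained JAX toy. Everything in it is hypothetical: toy_prepare, the plain-dict config and the decay model stand in for the library's prepare(), its Bunch config and real dynamics. It shows why a pure solve function is useful: it depends only on its config argument, so it can be compiled once with jax.jit and rerun with new parameter values.

```python
import jax
import jax.numpy as jnp


def toy_prepare(decay, x0, t0=0.0, t1=1.0, dt=0.1):
    """Hypothetical stand-in for prepare(): returns (solve_fn, config)."""
    # Static data (step count, step size) is fixed here, when the function is built.
    n_steps = int(round((t1 - t0) / dt))

    def solve_fn(config):
        # Pure: the output depends only on `config` and the closed-over constants.
        def step(x, _):
            x_next = x + dt * (-config["decay"] * x)  # forward Euler on dx/dt = -decay * x
            return x_next, x_next  # (new carry, emitted post-step state)

        _, xs = jax.lax.scan(step, config["x0"], xs=None, length=n_steps)
        return xs

    # Runtime data lives in a PyTree (here a plain dict) so it can be swapped freely.
    config = {"decay": jnp.asarray(decay), "x0": jnp.asarray(x0)}
    return solve_fn, config


solve_fn, config = toy_prepare(decay=1.0, x0=1.0)
fast_solve = jax.jit(solve_fn)  # compile once
xs = fast_solve(config)  # 10 post-step states for t = 0.1, ..., 1.0

# Same PyTree structure, shapes and dtypes, so the compiled function is reused.
xs_slow = fast_solve({**config, "decay": jnp.asarray(0.5)})
```

In this sketch the step count is baked in at prepare time, so changing t0, t1 or dt means calling toy_prepare again; only the config leaves are cheap to swap.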

Functions

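| Name    | Description                                                         |
|---------|---------------------------------------------------------------------|
| prepare | Compile a model into a pure JAX solve function and a config PyTree. |
| solve   | Main entry point for simulation.                                    |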

prepare

experimental.network_dynamics.solve.prepare(
    dynamics,
    solver,
    t0=0.0,
    t1=1.0,
    dt=0.1,
    n_nodes=1,
    noise=None,
    externals=None,
)

Compile a model into a pure JAX solve function and a config PyTree.

Builds per-dispatch data (coupling buffers, noise samples, external inputs) and returns (solve_fn, config), where solve_fn(config) runs the integration. Dispatches on the first two arguments via plum: a Network or AbstractDynamics model paired with a NativeSolver or DiffraxSolver.
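
What "dispatches on the first two arguments" means can be shown without plum. The sketch below is a hypothetical, simplified stand-in: the placeholder classes, register and prepare_sketch are invented for illustration, and the implementation is looked up by the pair (type(model), type(solver)). Unlike plum, which selects the most specific matching signature, this lookup just returns the first match found while walking both MROs.

```python
# Placeholder classes standing in for the library's types (hypothetical, illustration only).
class Network: ...
class AbstractDynamics: ...
class JansenRitLike(AbstractDynamics): ...  # a concrete dynamics subclass
class NativeSolver: ...
class DiffraxSolver: ...

# (model type, solver type) -> implementation
_REGISTRY = {}


def register(model_type, solver_type):
    """Register an implementation for one (model, solver) type pair."""
    def decorator(fn):
        _REGISTRY[(model_type, solver_type)] = fn
        return fn
    return decorator


@register(Network, NativeSolver)
def _network_native(model, solver, **kwargs):
    return "network/native"


@register(AbstractDynamics, DiffraxSolver)
def _dynamics_diffrax(model, solver, **kwargs):
    return "dynamics/diffrax"


def prepare_sketch(model, solver, **kwargs):
    """Pick an implementation from the types of BOTH arguments."""
    # Walk both MROs so a subclass (e.g. JansenRitLike) matches its registered base class.
    for m in type(model).__mro__:
        for s in type(solver).__mro__:
            fn = _REGISTRY.get((m, s))
            if fn is not None:
                return fn(model, solver, **kwargs)
    raise TypeError(
        f"no implementation for ({type(model).__name__}, {type(solver).__name__})"
    )


print(prepare_sketch(Network(), NativeSolver()))         # network/native
print(prepare_sketch(JansenRitLike(), DiffraxSolver()))  # dynamics/diffrax
```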

Parameters

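| Name      | Type                            | Description                                                                    | Default  |
|-----------|---------------------------------|--------------------------------------------------------------------------------|----------|
| dynamics  | Network or AbstractDynamics     | Model to compile; selects the dispatch together with solver.                   | required |
| solver    | NativeSolver or DiffraxSolver   | Integration backend; selects the dispatch together with dynamics.              | required |
| t0        | float                           | Start of the integration interval.                                             | 0.0      |
| t1        | float                           | End of the integration interval.                                               | 1.0      |
| dt        | float                           | Step size: the fixed step for native solvers and the initial step for Diffrax. | 0.1      |
| n_nodes   | int                             | Per-dispatch parameter (bare dynamics); see help(prepare).                     | 1        |
| noise     |                                 | Per-dispatch parameter (bare dynamics); see help(prepare).                     | None     |
| externals |                                 | Per-dispatch parameter (bare dynamics); see help(prepare).                     | None     |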
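
Because dt is the fixed step for native solvers, every step of a run has the same size. The sketch below uses the textbook forward Euler and Heun (explicit trapezoidal) updates on the toy problem dx/dt = -x, whose exact solution is exp(-t). It assumes, without confirming, that the library's Euler and Heun solvers follow these standard forms; integrate is a hypothetical helper.

```python
import math


def f(x):
    return -x  # toy ODE dx/dt = -x, exact solution x(t) = x0 * exp(-t)


def euler_step(x, dt):
    return x + dt * f(x)


def heun_step(x, dt):
    k1 = f(x)
    x_pred = x + dt * k1               # Euler predictor
    k2 = f(x_pred)
    return x + 0.5 * dt * (k1 + k2)    # trapezoidal corrector


def integrate(step, x0, t0, t1, dt):
    n_steps = round((t1 - t0) / dt)    # fixed dt: every step has the same size
    x = x0
    for _ in range(n_steps):
        x = step(x, dt)
    return x


exact = math.exp(-1.0)
err_euler = abs(integrate(euler_step, 1.0, 0.0, 1.0, 0.1) - exact)
err_heun = abs(integrate(heun_step, 1.0, 0.0, 1.0, 0.1) - exact)
print(err_euler, err_heun)
```

With the same dt = 0.1, the second-order Heun update ends roughly 30 times closer to exp(-1) than first-order Euler, at the cost of one extra evaluation of f per step.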

Returns

| Type              | Description                                               |
|-------------------|-----------------------------------------------------------|
| (Callable, Bunch) | Pure solve function and its runtime configuration PyTree. |

See help(prepare) or prepare.__doc__ for the full reference, including the per-dispatch parameters (n_nodes, noise, externals for bare dynamics) and the Diffrax limitations (no delays, no auxiliaries, no VOI filtering).

solve

experimental.network_dynamics.solve.solve(
    model,
    solver,
    t0=0.0,
    t1=100.0,
    dt=0.1,
    **kwargs,
)

Main entry point for simulation.

Accepts either a Network or a bare AbstractDynamics instance. Dispatches to the appropriate prepare() overload via plum.

Args:
    model: Network or AbstractDynamics instance.
    solver: NativeSolver or DiffraxSolver instance.
    t0: Start time.
    t1: End time (inclusive for native solvers; see the note on the time grid).
    dt: Time step.
    **kwargs: Additional arguments forwarded to prepare() (e.g. n_nodes for bare dynamics).

Returns:
    Simulation results wrapped in a result object.

Notes:
    Native solvers use the half-open scan grid arange(t0, t1, dt) and emit the post-step state on each iteration, so the returned save grid is (t0, t1]:

        result.ts = [t0 + dt, t0 + 2*dt, ..., t1]

    The initial state at t0 is excluded and the endpoint t1 is included. The number of saved samples is (t1 - t0) / dt, and t1 - t0 must be an integer multiple of dt for the grid to land exactly on t1.
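
The save-grid rule in the Notes can be checked with a few lines of plain Python. native_save_grid is a hypothetical helper written for this page, not part of the library. It counts steps with round() and rejects intervals that are not integer multiples of dt, which avoids relying on the length of a floating-point arange; NumPy's own documentation warns that arange with a non-integer step can give an inconsistent length.

```python
def native_save_grid(t0, t1, dt, rtol=1e-9):
    """Reproduce the (t0, t1] save grid described in the Notes (illustrative only)."""
    ratio = (t1 - t0) / dt
    n_steps = round(ratio)
    if abs(ratio - n_steps) > rtol * max(1.0, abs(ratio)):
        raise ValueError("t1 - t0 is not an integer multiple of dt")
    # Step k (1-based) emits the post-step state at t0 + k*dt; t0 itself is not saved.
    return [t0 + k * dt for k in range(1, n_steps + 1)]


ts = native_save_grid(0.0, 1.0, 0.1)
print(len(ts), ts[0], ts[-1])
```

For t0=0, t1=10, dt=0.01 (the Network example) the same rule gives 1000 samples, starting at 0.01 and ending at t1.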

Examples:
    >>> # With Network
    >>> result = solve(network, Euler(), t0=0, t1=10, dt=0.01)

    >>> # With bare dynamics (single node)
    >>> result = solve(JansenRit(), Heun(), t0=0, t1=1.0, dt=0.001)

    >>> # With bare dynamics (multi-node uncoupled)
    >>> result = solve(JansenRit(), Heun(), t0=0, t1=1.0, dt=0.001, n_nodes=3)
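
For bare dynamics with n_nodes > 1 the nodes are uncoupled, so each node's trajectory depends only on its own state and parameters. The NumPy sketch below checks that idea on hypothetical toy dynamics dx/dt = -rate * x, stepping all nodes together along a node axis and comparing against running each node separately. The per-node rates and the (n_steps, n_nodes) layout are choices made for this sketch only; they do not describe the library's parameters or result layout.

```python
import numpy as np


def step(x, dt, rate):
    # Elementwise update: a leading node axis is carried along with no
    # interaction between nodes (i.e. the nodes are uncoupled).
    return x + dt * (-rate * x)


def run(x0, rate, n_steps=10, dt=0.1):
    xs = []
    x = x0
    for _ in range(n_steps):
        x = step(x, dt, rate)
        xs.append(x)  # post-step state, as in the native save grid
    return np.stack(xs)  # shape (n_steps, *x0.shape)


rates = np.array([0.5, 1.0, 2.0])  # hypothetical per-node parameter
x0 = np.ones(3)                    # n_nodes = 3

batched = run(x0, rates)  # all nodes at once, shape (10, 3)
separate = np.stack([run(x0[i], rates[i]) for i in range(3)], axis=1)
assert np.allclose(batched, separate)
```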