tvbo.prepare

tvbo.prepare(
    experiment,
    t0=0.0,
    t1=100.0,
    dt=0.1,
    enable_x64=True,
    replace_temporal_averaging=False,
    return_new_ics=False,
    scalar_pre=False,
    bold_fft_convolve=True,
    small_dt=False,
    **kwargs,
)

Convert a TVBO SimulationExperiment into a JAX-compatible model function and state.

This function transforms a TVBO simulation experiment into a JAX-compiled model function and a corresponding state object for efficient brain simulation. The resulting model supports automatic differentiation and parallel execution.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| experiment | tvbo.export.experiment.SimulationExperiment | TVBO SimulationExperiment containing model, connectivity, coupling, integration, and monitor specifications. | required |
| t0 | float | Start time of the simulation. Currently not used; reserved for future integration. | 0.0 |
| t1 | float | End time of the simulation. Currently not used; reserved for future integration. | 100.0 |
| dt | float | Time step of the simulation. Currently not used; reserved for future integration. | 0.1 |
| enable_x64 | bool | If True, use float64 precision; otherwise float32. Casts all arrays in state to the selected precision and sets the JAX config flag 'jax_enable_x64'. | True |
| replace_temporal_averaging | bool | If False, BOLD uses the TemporalAverage monitor, as TVB does. If True, BOLD uses the faster SubSample monitor, which gives similar results. | False |
| return_new_ics | bool | If True, the model returns an updated initial-conditions TimeSeries along with the simulation output, so that a simulation can be continued. Changes the output from result to [result, initial_conditions]. | False |
| scalar_pre | bool | If True, applies a performance optimization that replaces the dot product with a matmul in the coupling term. Only works when the pre expression is scalar-only, there are no delays, and the pre expression contains a single occurrence of x_j. | False |
| bold_fft_convolve | bool | If True, the BOLD monitor uses FFT convolution instead of a dot product. This is faster in most cases, and its runtime does not scale with the BOLD period. The dot product can be faster for large period values. | True |
| small_dt | bool | If True, stores the full history, which speeds up simulations at small dt. Can cause memory use to explode under the jax.grad transformation. | False |
| **kwargs | dict | Additional keyword arguments passed to experiment.execute(). | {} |
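
The precision handling described for enable_x64 can be illustrated in plain JAX. This is a minimal sketch, not the library's implementation: the `state` dictionary and the `cast_floats` helper are hypothetical stand-ins introduced here, while `jax.config.update` and `jax.tree_util.tree_map` are standard JAX calls.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit support; without this flag JAX keeps arrays in 32-bit precision.
jax.config.update("jax_enable_x64", True)

# Hypothetical stand-in for a simulation state (not the real prepare() output).
state = {
    "coupling_strength": jnp.asarray(0.5, dtype=jnp.float32),
    "initial_conditions": jnp.zeros((2, 4), dtype=jnp.float32),
}


def cast_floats(tree, dtype):
    """Cast every floating-point leaf of a PyTree to dtype; leave other leaves as-is."""

    def cast(leaf):
        if jnp.issubdtype(leaf.dtype, jnp.floating):
            return leaf.astype(dtype)
        return leaf

    return jax.tree_util.tree_map(cast, tree)


state64 = cast_floats(state, jnp.float64)    # analogous to enable_x64=True
state32 = cast_floats(state64, jnp.float32)  # analogous to enable_x64=False
```

The flag is best set before any arrays are created, since JAX reads it when choosing default dtypes.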

Returns

| Name | Type | Description |
| --- | --- | --- |
| | tuple[Callable, Any] | A tuple (model_function, state), where model_function is a callable that takes state and returns the simulation results, and state is a JAX PyTree containing all simulation parameters and initial conditions. |
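
The (model_function, state) contract, including the [result, initial_conditions] output documented for return_new_ics, can be sketched with a toy stand-in. Everything here is hypothetical: `make_toy_model`, the dictionary keys, and the simple decay dynamics are illustrative assumptions and say nothing about how prepare builds its model internally.

```python
import jax.numpy as jnp


def make_toy_model(n_steps, dt, return_new_ics=False):
    """Build a toy model(state) callable; a stand-in, not the real prepare()."""

    def model(state):
        x = state["initial_conditions"]
        trajectory = []
        for _ in range(n_steps):
            # Explicit Euler step of dx/dt = -rate * x
            x = x + dt * (-state["rate"] * x)
            trajectory.append(x)
        result = jnp.stack(trajectory)
        if return_new_ics:
            # Mirror the documented [result, initial_conditions] output
            return [result, x]
        return result

    return model


state = {"rate": jnp.asarray(1.0), "initial_conditions": jnp.ones(3)}
model = make_toy_model(n_steps=10, dt=0.1, return_new_ics=True)

# Run a first segment, then continue from the returned final state
first, new_ics = model(state)
second, _ = model({**state, "initial_conditions": new_ics})
```

In this sketch, concatenating `first` and `second` reproduces a single 20-step run, which is the property that makes returned initial conditions useful for continuing a simulation.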

Examples

>>> import jax
>>> from tvbo.export.experiment import SimulationExperiment
>>> from tvboptim import prepare
>>>
>>> # Create TVBO experiment
>>> experiment = SimulationExperiment(...)
>>>
>>> # Convert to JAX
>>> model, state = prepare(experiment, enable_x64=True, scalar_pre=True)
>>>
>>> # Run simulation (unpacking assumes the experiment defines two monitors, raw and BOLD)
>>> result = model(state)
>>> raw_data, bold_data = result
>>>
>>> # Use with JAX transformations
>>> grad_fn = jax.grad(lambda s: model(s)[0].data.sum())
>>> gradients = grad_fn(state)

Notes

The returned model function is JAX-compiled and supports:

  • Automatic differentiation with jax.grad, jax.jacobian
  • Parallel execution with jax.vmap, jax.pmap
  • Just-in-time compilation for optimal performance
  • Integration with JAX ecosystem (optax, equinox, etc.)

The state object is a JAX PyTree that can be used with all JAX transformations and contains Parameter objects for optimization workflows.
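
How a PyTree state interacts with these transformations can be sketched without a real experiment. In the example below, `toy_model` and the `state` dictionary are hypothetical stand-ins for the objects returned by prepare; only the JAX calls (`jax.grad`, `jax.vmap`, `jax.jit`, `jax.lax.scan`) are real.

```python
import jax
import jax.numpy as jnp


def toy_model(state):
    """Stand-in for model_function: Euler-integrate dx/dt = -rate * x for 50 steps."""

    def step(x, _):
        x_next = x + 0.1 * (-state["rate"] * x)
        return x_next, x_next

    _, trajectory = jax.lax.scan(step, state["initial_conditions"], None, length=50)
    return trajectory  # shape (50, n_nodes)


state = {"rate": jnp.asarray(0.5), "initial_conditions": jnp.ones(4)}


def loss(s):
    # Scalar summary of the simulated trajectory
    return jnp.sum(toy_model(s) ** 2)


# The gradient is a PyTree with the same structure as state
gradients = jax.grad(loss)(state)

# Sweep one parameter in parallel while sharing the remaining state
rates = jnp.linspace(0.1, 1.0, 8)
batched = jax.vmap(lambda r: toy_model({**state, "rate": r}))(rates)

# Just-in-time compile the same callable
fast_model = jax.jit(toy_model)
```

Because the gradient mirrors the structure of the state, it can be passed directly to a PyTree-aware optimizer to update selected entries.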