tvbo.prepare
tvbo.prepare(
    experiment,
    t0=0.0,
    t1=100.0,
    dt=0.1,
    enable_x64=True,
    replace_temporal_averaging=False,
    return_new_ics=False,
    scalar_pre=False,
    bold_fft_convolve=True,
    small_dt=False,
    **kwargs,
)
Convert a TVBO SimulationExperiment into a JAX-compatible model function and state.
This function transforms a TVBO simulation experiment into a JAX-compiled model function and a matching state object for efficient brain network simulation. The resulting model supports automatic differentiation and vectorized or parallel execution through JAX transformations.
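
The split into a model function and a state can be illustrated without TVBO. The following is a minimal NumPy sketch, not tvbo's implementation: prepare_toy, its linear network dynamics dx/dt = -x + G·W·x, and the state keys are all hypothetical. The point is that the returned model is a pure function of its state, which is the property that lets JAX transform the real model function.

```python
import numpy as np


def prepare_toy(weights, coupling_strength=0.1, dt=0.1, n_steps=100):
    """Hypothetical stand-in for prepare(): returns (model, state).

    The dynamics dx/dt = -x + G * W @ x are a toy linear network,
    not a TVBO model.
    """
    n_nodes = weights.shape[0]
    # Everything that may change between runs lives in the state.
    state = {
        "weights": np.asarray(weights, dtype=np.float64),
        "G": float(coupling_strength),
        "x0": np.ones(n_nodes),
    }

    def model(state):
        # Pure function: the output depends only on `state`
        # (dt and n_steps are fixed when the model is built).
        x = state["x0"]
        trace = []
        for _ in range(n_steps):
            x = x + dt * (-x + state["G"] * state["weights"] @ x)
            trace.append(x)
        return np.stack(trace)

    return model, state


W = np.array([[0.0, 1.0], [1.0, 0.0]])
model, state = prepare_toy(W, n_steps=50)
result = model(state)
# Changing a parameter only requires a new state, not a new model.
stronger = model({**state, "G": 0.5})
print(result.shape)  # (50, 2)
```

Because dt and n_steps are captured when the model is built while G and x0 are read from the state, a parameter sweep in this sketch only needs new state objects.
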
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| experiment | tvbo.export.experiment.SimulationExperiment | TVBO SimulationExperiment containing model, connectivity, coupling, integration, and monitor specifications. | required |
| t0 | float | Start time of the simulation. Currently unused; reserved for future integration. | 0.0 |
| t1 | float | End time of the simulation. Currently unused; reserved for future integration. | 100.0 |
| dt | float | Simulation time step. Currently unused; reserved for future integration. | 0.1 |
| enable_x64 | bool | If True, use float64 precision; otherwise float32. Casts all arrays in the state to the selected precision and sets the JAX config option ‘jax_enable_x64’. | True |
| replace_temporal_averaging | bool | If False, the BOLD monitor uses a TemporalAverage monitor, as TVB does. If True, it uses the faster SubSample monitor, which gives similar results. | False |
| return_new_ics | bool | If True, the model returns updated initial conditions (a TimeSeries) together with the simulation output, so that a simulation can be continued. The output changes from result to [result, initial_conditions]. | False |
| scalar_pre | bool | If True, applies a performance optimization that replaces the dot product in the coupling term with a matmul. Only valid when the pre expression is scalar-only and contains a single occurrence of x_j, and when there are no delays. | False |
| bold_fft_convolve | bool | If True, the BOLD monitor uses FFT convolution instead of a dot product. This is faster in most cases, and its runtime does not scale with the BOLD period; the dot product can be faster for large period values. | True |
| small_dt | bool | If True, stores the full history, which speeds up simulations with a small dt. This can cause excessive memory use under the jax.grad transformation. | False |
| **kwargs | dict | Additional keyword arguments passed to experiment.execute(). | {} |
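
The scalar_pre and bold_fft_convolve flags both replace one computation with an equivalent, faster one. The NumPy sketch below checks the two equivalences on random data; it does not reproduce tvbo's internal code. The pre expression a * x_j, the scale a, and the gamma-shaped kernel are hypothetical examples chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# --- scalar_pre: per-pair coupling vs. a single matmul -------------------
n = 5
W = rng.random((n, n))      # connectivity weights w_ij
x = rng.random(n)           # current node states x_j
a = 0.3                     # hypothetical scale in pre(x_i, x_j) = a * x_j

# General form: build pre(x_i, x_j) for every pair, then reduce row-wise.
pre = np.broadcast_to(a * x[np.newaxis, :], (n, n))
coupling_general = np.sum(W * pre, axis=1)

# When pre depends on a single x_j only (no delays), the sum is a matmul.
coupling_matmul = a * (W @ x)

# --- bold_fft_convolve: direct vs. FFT convolution ------------------------
t = np.arange(0.0, 20.0, 0.5)
kernel = t ** 2 * np.exp(-t)          # hypothetical gamma-shaped kernel
signal = rng.random(400)              # stand-in for a neural signal

direct = np.convolve(signal, kernel)  # O(len(signal) * len(kernel))

size = signal.size + kernel.size - 1  # full linear-convolution length
# Zero-padding both inputs to `size` makes the circular FFT product
# equal to the linear convolution.
fft_conv = np.fft.irfft(
    np.fft.rfft(signal, size) * np.fft.rfft(kernel, size), size
)

print(np.allclose(coupling_general, coupling_matmul))  # True
print(np.allclose(direct, fft_conv))                   # True
```

For a signal of length N and a kernel of length M, direct convolution costs O(N·M) while the FFT route costs O((N+M) log(N+M)); which one wins in practice depends on the sizes involved.
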
Returns
| Type | Description |
|---|---|
| tuple[Callable, Any] | A tuple (model_function, state), where model_function is a callable that takes the state and returns the simulation results, and state is a JAX PyTree containing all simulation parameters and initial conditions. |
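
With return_new_ics=True, the model also returns initial conditions that can seed the next call. The idea can be sketched with a hypothetical delay-free Euler integrator in NumPy (simulate and its linear dynamics are assumptions, not tvbo code): chaining two half-length runs through the returned state reproduces one full-length run.

```python
import numpy as np


def simulate(x0, n_steps, dt=0.1, a=-0.5):
    """Hypothetical Euler integrator for dx/dt = a * x.

    Returns the trajectory and the final state, mimicking the
    [result, initial_conditions] pair described for return_new_ics=True.
    """
    x = np.asarray(x0, dtype=np.float64)
    trace = []
    for _ in range(n_steps):
        x = x + dt * a * x
        trace.append(x)
    return np.stack(trace), x


x0 = np.array([1.0, 2.0, 3.0])

# One long run ...
full, _ = simulate(x0, 200)

# ... equals two chained runs, where the second starts from the
# state returned by the first.
first, new_ics = simulate(x0, 100)
second, _ = simulate(new_ics, 100)
chained = np.concatenate([first, second])

print(np.allclose(full, chained))  # True
```

With delayed coupling, a faithful continuation also needs the recent state history, not only the final state; this sketch omits delays.
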
Examples
>>> import jax
>>> from tvbo.export.experiment import SimulationExperiment
>>> from tvboptim import prepare
>>>
>>> # Create TVBO experiment
>>> experiment = SimulationExperiment(...)
>>>
>>> # Convert to JAX
>>> model, state = prepare(experiment, enable_x64=True, scalar_pre=True)
>>>
>>> # Run simulation
>>> result = model(state)
>>> raw_data, bold_data = result  # assumes the experiment defines a raw and a BOLD monitor
>>>
>>> # Use with JAX transformations
>>> grad_fn = jax.grad(lambda s: model(s)[0].data.sum())
>>> gradients = grad_fn(state)
Notes
The returned model function is JAX-compiled and supports:
- Automatic differentiation with jax.grad and jax.jacobian
- Vectorization with jax.vmap and multi-device parallelism with jax.pmap
- Just-in-time compilation with jax.jit
- Integration with the JAX ecosystem (optax, equinox, etc.)
The state object is a JAX PyTree that can be used with all JAX transformations and contains Parameter objects for optimization workflows.