jaxify

jaxify(experiment, enable_x64=True, **kwargs)

Convert a TVBO SimulationExperiment into a JAX-compatible model function and a state collection.

Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| experiment | TVBO SimulationExperiment | The TVBO SimulationExperiment to convert. | required |
| enable_x64 | bool | If True, use float64 precision; otherwise use float32. Casts all arrays in the state to the chosen precision and sets the JAX config option jax_enable_x64. | True |
| replace_temporal_averaging | bool | If False, BOLD uses a TemporalAverage monitor, as TVB does. If True, BOLD uses the faster SubSample monitor instead, with similar results. | keyword-only, default not documented |
| return_new_ics | bool | If True, the model returns an updated initial-conditions TimeSeries along with the simulation output, so that a simulation can be continued. The output changes from result to [result, initial_conditions]. | keyword-only, default not documented |
| scalar_pre | bool | If True, applies a performance optimization that replaces the dot product in the coupling term with a matrix multiplication. Only valid when the pre expression is scalar-only and contains a single occurrence of x_j, and when there are no delays. | keyword-only, default not documented |
| bold_fft_convolve | bool | If True, the BOLD monitor uses FFT convolution instead of a dot product. This is faster in most cases, and its run time does not scale with the BOLD period. For large period values, the dot product can be faster. | keyword-only, default not documented |
| small_dt | bool | If True, stores the full history, which speeds up simulations at small dt. Can cause a memory explosion under the jax.grad transformation. | keyword-only, default not documented |
| **kwargs | dict | Additional keyword arguments passed to downstream functions. | {} |
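The precision handling described for enable_x64 roughly corresponds to the following plain-JAX sketch. It assumes the state is a JAX pytree (for example a dict of arrays); set_precision is a hypothetical helper written for illustration, not part of TVBO, and the actual jaxify implementation may differ.

```python
import jax
import jax.numpy as jnp


def set_precision(state, enable_x64=True):
    """Hypothetical helper: mimic the enable_x64 behaviour on a pytree state."""
    # Toggle JAX's global 64-bit mode; without it, JAX stores float64 data as float32.
    jax.config.update("jax_enable_x64", enable_x64)
    dtype = jnp.float64 if enable_x64 else jnp.float32

    def cast(leaf):
        leaf = jnp.asarray(leaf)
        # Only floating-point leaves change precision; integer index arrays are kept.
        if jnp.issubdtype(leaf.dtype, jnp.floating):
            return leaf.astype(dtype)
        return leaf

    # Apply the cast to every leaf of the (possibly nested) state.
    return jax.tree_util.tree_map(cast, state)


state = {"x": jnp.ones(3), "idx": jnp.arange(3), "dt": 0.1}
state64 = set_precision(state, enable_x64=True)
```

Because jax_enable_x64 is a global flag, changing it affects every array created afterwards, not only the returned state.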
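The continuation pattern enabled by return_new_ics can be sketched with a toy model. Here simulate is a hypothetical forward-Euler integrator of dx/dt = -k x that returns its trajectory together with its final state, standing in for the [result, initial_conditions] output; it is not TVBO code. With delayed coupling, a real continuation needs the recent history rather than a single state, which this toy ignores.

```python
import jax
import jax.numpy as jnp


def simulate(x0, n_steps, k=0.5, dt=0.01):
    """Hypothetical toy model returning (trajectory, new_initial_conditions)."""

    def step(x, _):
        # One forward-Euler step of dx/dt = -k * x.
        x_next = x - dt * k * x
        return x_next, x_next

    x_final, trajectory = jax.lax.scan(step, x0, None, length=n_steps)
    return trajectory, x_final


x0 = jnp.array([1.0, 2.0])
traj_a, ics = simulate(x0, 50)   # first segment, plus state to continue from
traj_b, _ = simulate(ics, 50)    # second segment starts where the first ended
traj_full, _ = simulate(x0, 100) # one uninterrupted run for comparison
```

Concatenating traj_a and traj_b reproduces traj_full, which is the property a continuation output has to preserve.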
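The idea behind scalar_pre can be illustrated under a simplifying assumption: the pre expression is c * x_j, with a hypothetical constant c, and there are no delays. In the general form, pre is evaluated for every (i, j) pair and then weighted and summed over j; when pre depends only on x_j, the weighted sum collapses into a single matrix-vector product. With delays, each pair reads x_j at a different past time, so there is no single vector to multiply. Both functions below are illustrative sketches, not TVBO's coupling code.

```python
import jax.numpy as jnp


def coupling_pairwise(weights, x, c=1.0):
    """General form: evaluate pre for each (i, j) pair, weight it, sum over j."""
    pre = c * jnp.broadcast_to(x[None, :], weights.shape)  # shape (n, n)
    return jnp.sum(weights * pre, axis=1)


def coupling_matmul(weights, x, c=1.0):
    """Assumed scalar pre = c * x_j without delays: one matrix-vector product."""
    return weights @ (c * x)


W = jnp.arange(16.0).reshape(4, 4) / 10.0  # toy connectivity weights
x = jnp.array([1.0, -2.0, 0.5, 3.0])       # toy node states
```

The matmul form avoids materializing the n-by-n array of pre values, which is where the performance gain comes from.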
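The trade-off behind bold_fft_convolve can be illustrated with a generic causal convolution y[t] = sum_k h[k] x[t - k], computed once as one dot product per output sample and once with jax.scipy.signal.fftconvolve. This is a sketch, not TVBO's BOLD monitor: bold_direct and bold_fft are hypothetical helpers, and the exponential kernel is an arbitrary stand-in for a haemodynamic response function.

```python
import jax.numpy as jnp
from jax.scipy.signal import fftconvolve


def bold_direct(x, h):
    """Causal convolution as one dot product per output sample."""
    n, m = x.shape[0], h.shape[0]
    # Zero history before t = 0, so window t covers x[t - m + 1 .. t].
    xp = jnp.concatenate([jnp.zeros(m - 1, x.dtype), x])
    windows = jnp.stack([xp[t : t + m] for t in range(n)])  # shape (n, m)
    return windows @ h[::-1]


def bold_fft(x, h):
    """Same causal convolution via FFT; keep the first n samples of the full result."""
    return fftconvolve(x, h, mode="full")[: x.shape[0]]


x = jnp.sin(0.1 * jnp.arange(50.0))   # toy neural signal
h = jnp.exp(-jnp.arange(10.0) / 3.0)  # toy kernel
max_diff = jnp.max(jnp.abs(bold_direct(x, h) - bold_fft(x, h)))  # float rounding only
```

The two results agree up to floating-point rounding; which one is faster depends on the signal and kernel lengths.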

Returns

| Type | Description |
|---|---|
| tuple | A tuple (model_function, state_collection), where model_function takes state_collection and returns the simulation results. With return_new_ics=True, model_function returns [result, initial_conditions] instead. Usage: result = model_function(state_collection) |
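The (model_function, state_collection) contract means the model is a pure function of its state, so standard JAX transformations apply to it. The sketch below uses toy_jaxify, a hypothetical stand-in that builds a linear-decay model instead of converting a real TVBO SimulationExperiment; only the calling pattern mirrors the documented usage.

```python
import jax
import jax.numpy as jnp


def toy_jaxify(n_steps=100):
    """Hypothetical stand-in for jaxify: returns (model_function, state_collection)."""
    state_collection = {
        "k": jnp.array(0.5),         # decay rate
        "dt": jnp.array(0.01),       # integration step
        "x0": jnp.array([1.0, 2.0]), # initial conditions
    }

    def model_function(state):
        def step(x, _):
            # Forward-Euler step of dx/dt = -k * x.
            x_next = x + state["dt"] * (-state["k"] * x)
            return x_next, x_next

        _, trajectory = jax.lax.scan(step, state["x0"], None, length=n_steps)
        return trajectory

    return model_function, state_collection


model_function, state_collection = toy_jaxify()
result = model_function(state_collection)  # documented call pattern
fast_model = jax.jit(model_function)       # compile the whole simulation


def loss(state):
    return jnp.sum(model_function(state)[-1] ** 2)


grads = jax.grad(loss)(state_collection)   # gradient for every state leaf
```

Since grads has the same structure as state_collection, it can feed directly into a gradient-based parameter fit.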