# tvboptim Backend
JAX-based optimization and parameter fitting for brain network models
## Overview
The tvboptim backend provides high-performance simulation and optimization capabilities using JAX. It enables:
- JIT-compiled simulations for fast execution
- Automatic differentiation for gradient-based optimization
- GPU acceleration when available
- Declarative YAML specification of optimization workflows
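To make these points concrete, here is a minimal, self-contained JAX sketch, not part of the tvboptim API, showing how JIT compilation and automatic differentiation combine to push a gradient through a rolled-out toy network simulation:

```python
import jax
import jax.numpy as jnp

# Toy linear network model: x_{t+1} = x_t + dt * (-x_t + G * C @ x_t),
# where G is a global coupling parameter we want a gradient for.
def simulate(G, x0, C, dt=0.1, steps=200):
    def step(x, _):
        x_next = x + dt * (-x + G * (C @ x))
        return x_next, x_next
    _, trajectory = jax.lax.scan(step, x0, None, length=steps)
    return trajectory  # shape (steps, regions)

def loss(G, x0, C, target_mean):
    traj = simulate(G, x0, C)
    return jnp.mean((traj.mean(axis=0) - target_mean) ** 2)

# JIT-compile the loss and differentiate it with respect to G.
loss_and_grad = jax.jit(jax.value_and_grad(loss))

C = jax.random.uniform(jax.random.PRNGKey(0), (4, 4)) * 0.1
value, dloss_dG = loss_and_grad(0.5, jnp.ones(4), C, jnp.zeros(4))
```

Because the whole loop is traced by JAX, the same code runs unchanged on CPU or GPU.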
## Workflow Examples
| Workflow | Model | Description |
|---|---|---|
| Jansen-Rit | Jansen-Rit | MEG frequency optimization |
| Reduced Wong-Wang | RWW | Functional connectivity fitting |
| EI Tuning | RWW | Excitation-Inhibition balance tuning |
## Quick Start
```python
from tvbo import SimulationExperiment

# Load experiment from YAML
exp = SimulationExperiment.from_file("database/studies/JR_MEG_FrequencyGradient_Optimization.yaml")

# Run optimization
results = exp.run()

# Access results
print(results)
```

## Result Structure
Results follow a declarative access pattern mirroring the YAML structure:
```
ExperimentResult
├── integration
│   └── main
│       ├── data: Array(duration, regions, state_vars)
│       └── observations
│           ├── bold: Array(...)
│           └── fc: Array(regions, regions)
├── algorithms
│   └── fic
│       ├── state: Final tuned parameters
│       └── history: Per-iteration tracking
└── optimization
    └── global_optimization
        ├── state: Optimized parameters
        ├── loss: Final loss value
        ├── history: Loss curve, parameters
        └── simulation
            ├── data: Post-optimization simulation
            └── observations: BOLD, FC, etc.
```
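In code, this means results are read with attribute-style paths that mirror the YAML hierarchy. The snippet below is illustrative, based on the tree above; field names in a given experiment follow its own specification:

```python
# Attribute-style access mirroring the structure above (illustrative).
timeseries = results.integration.main.data                  # (duration, regions, state_vars)
fc = results.integration.main.observations.fc               # (regions, regions)

tuned = results.algorithms.fic.state                        # final tuned parameters
final_loss = results.optimization.global_optimization.loss  # final loss value
bold = results.optimization.global_optimization.simulation.observations.bold
```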
## Key Features
### Automatic Differentiation
tvboptim uses JAX’s autodiff to compute gradients through entire simulations:
```yaml
optimization:
  global_optimization:
    parameters:
      - dynamics.G
    loss_function: loss_fc
    optimizer:
      name: Adam
      learning_rate: 0.01
    max_steps: 100
```
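Conceptually, that configuration corresponds to a gradient-descent loop along the lines of the sketch below, written here with optax purely for illustration; tvboptim's internals may differ, and `simulate_fc` / `empirical_fc` are placeholders:

```python
import jax
import jax.numpy as jnp
import optax

empirical_fc = jnp.eye(4)                   # placeholder empirical FC matrix

def simulate_fc(G):
    # Placeholder for "simulate the model at coupling G and return its FC".
    return G * jnp.ones((4, 4))

def loss_fc(G):
    return jnp.mean((simulate_fc(G) - empirical_fc) ** 2)

optimizer = optax.adam(learning_rate=0.01)  # optimizer: Adam, learning_rate: 0.01
G = jnp.asarray(1.0)                        # parameter being fit: dynamics.G
opt_state = optimizer.init(G)

loss_and_grad = jax.jit(jax.value_and_grad(loss_fc))
for step in range(100):                     # max_steps: 100
    loss, grad = loss_and_grad(G)
    updates, opt_state = optimizer.update(grad, opt_state)
    G = optax.apply_updates(G, updates)
```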
### Smart Interval Defaults

Print and save intervals are automatically set based on the iteration count:

- 1-10 iterations → every 1
- 10-100 iterations → every 10
- 100-1000 iterations → every 100
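A plausible way to express this heuristic (a guess at the rule, not necessarily the exact implementation) is a power of ten scaled to the step count:

```python
import math

def default_interval(max_steps: int) -> int:
    # 1-10 steps -> every 1, 10-100 -> every 10, 100-1000 -> every 100, ...
    return 10 ** math.floor(math.log10(max(max_steps, 1)))
```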
### Multi-Stage Optimization
Chain multiple optimization stages with different objectives:
```yaml
optimization:
  stage1:
    parameters: [dynamics.G]
    loss_function: loss_fc
    max_steps: 100
  stage2:
    parameters: [dynamics.J_i]
    loss_function: loss_bold
    max_steps: 50
```
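Assuming each stage gets its own entry in the result tree, in line with the Result Structure section above, per-stage outcomes can be read like this (field names illustrative):

```python
# Illustrative per-stage access, assuming one result entry per stage.
g_stage1 = results.optimization.stage1.state.dynamics.G
j_i_stage2 = results.optimization.stage2.state.dynamics.J_i
print(results.optimization.stage1.loss, results.optimization.stage2.loss)
```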
### State Trajectory

Track parameter evolution during optimization:
```python
trajectory = results.optimization.global_optimization.state_trajectory
for state in trajectory:
    print(f"G={state.dynamics.G}")
```
## See Also

- Jansen-Rit Workflow: Complete MEG optimization example
- RWW Workflow: Functional connectivity fitting
- EI Tuning: E/I balance optimization