tvboptim Backend

JAX-based optimization and parameter fitting for brain network models

Overview

The tvboptim backend provides high-performance simulation and optimization capabilities using JAX. It enables:

  • JIT-compiled simulations for fast execution
  • Automatic differentiation for gradient-based optimization
  • GPU acceleration when available
  • Declarative YAML specification of optimization workflows

Workflow Examples

Workflow            Model        Description
------------------  -----------  ------------------------------------
Jansen-Rit          Jansen-Rit   MEG frequency optimization
Reduced Wong-Wang   RWW          Functional connectivity fitting
EI Tuning           RWW          Excitation-Inhibition balance tuning

Quick Start

from tvbo import SimulationExperiment

# Load experiment from YAML
exp = SimulationExperiment.from_file("database/studies/JR_MEG_FrequencyGradient_Optimization.yaml")

# Run optimization
results = exp.run()

# Access results
print(results)

Result Structure

Results are accessed through nested attributes that mirror the YAML structure (for example, results.optimization.global_optimization.state):

ExperimentResult:
├── integration
│   └── main
│       ├── data: Array(duration, regions, state_vars)
│       └── observations
│           ├── bold: Array(...)
│           └── fc: Array(regions, regions)
├── algorithms
│   └── fic
│       ├── state: Final tuned parameters
│       └── history: Per-iteration tracking
└── optimization
    └── global_optimization
        ├── state: Optimized parameters
        ├── loss: Final loss value
        ├── history: Loss curve, parameters
        └── simulation
            ├── data: Post-optimization simulation
            └── observations: BOLD, FC, etc.
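The attribute-style access shown in the tree can be mimicked with nested namespaces. This is a minimal sketch assuming results behave like nested objects built from the YAML keys; to_namespace and the numbers below are hypothetical placeholders, not tvbo's actual ExperimentResult class or real output.

```python
from types import SimpleNamespace


def to_namespace(node):
    """Recursively convert nested dicts into objects with attribute access.

    Hypothetical helper for illustration; tvbo's result class may differ.
    """
    if isinstance(node, dict):
        return SimpleNamespace(**{key: to_namespace(value) for key, value in node.items()})
    return node


# Placeholder values only, shaped like the optimization branch of the tree.
raw = {
    "optimization": {
        "global_optimization": {
            "state": {"dynamics": {"G": 0.5}},
            "loss": 0.0123,
        }
    }
}

results = to_namespace(raw)
print(results.optimization.global_optimization.state.dynamics.G)
print(results.optimization.global_optimization.loss)
```

Each YAML key becomes one attribute level, so the access path reads the same as the configuration that produced it.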

Key Features

Automatic Differentiation

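tvboptim uses JAX's automatic differentiation to compute gradients of the loss through the entire simulation, so the parameters listed in an optimization block can be fitted with gradient-based optimizers such as Adam: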

optimization:
  global_optimization:
    parameters:
      - dynamics.G
    loss_function: loss_fc
    optimizer:
      name: Adam
      learning_rate: 0.01
    max_steps: 100
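To make the optimizer block above concrete, here is a minimal sketch of the standard Adam update applied to a single parameter G for max_steps iterations at learning_rate 0.01. It assumes the Adam entry refers to the textbook Adam rule with default betas (0.9, 0.999); tvboptim's own implementation is not shown on this page. The toy loss (G - 0.5)**2 and its hand-written gradient stand in for loss_fc and for the gradient JAX would obtain by differentiating through the simulation. adam_minimize, toy_loss, and toy_grad are hypothetical names.

```python
import math


def adam_minimize(grad_fn, x0, learning_rate=0.01, max_steps=100,
                  beta1=0.9, beta2=0.999, eps=1e-8):
    """Minimize a scalar function with the standard Adam update.

    Hypothetical helper for illustration; not part of tvboptim.
    """
    x = x0
    m = 0.0  # running estimate of the gradient mean
    v = 0.0  # running estimate of the squared gradient
    history = []
    for t in range(1, max_steps + 1):
        g = grad_fn(x)
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        m_hat = m / (1 - beta1 ** t)  # bias correction for the zero init
        v_hat = v / (1 - beta2 ** t)
        x = x - learning_rate * m_hat / (math.sqrt(v_hat) + eps)
        history.append(x)
    return x, history


# Toy stand-in for loss_fc: a quadratic with its minimum at G = 0.5.
def toy_loss(G):
    return (G - 0.5) ** 2


def toy_grad(G):
    # Hand-derived gradient; tvboptim would get this from JAX autodiff.
    return 2.0 * (G - 0.5)


G_final, G_history = adam_minimize(toy_grad, x0=0.2, learning_rate=0.01, max_steps=100)
print(f"G after {len(G_history)} steps: {G_final:.4f}, loss: {toy_loss(G_final):.6f}")
```

Because Adam divides each step by the running gradient magnitude, early steps move G by roughly learning_rate per iteration regardless of the loss scale, so learning_rate and max_steps together roughly limit how far a parameter can travel in one run.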

Smart Interval Defaults

Print and save intervals are set automatically from the iteration count:

  • 1-10 iterations → every 1
  • 10-100 iterations → every 10
  • 100-1000 iterations → every 100
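The interval rule can be expressed as a small function. This is a sketch under two assumptions the page does not state: because the listed ranges share their endpoints (10 and 100), upper bounds are treated as inclusive, and the pattern is extrapolated by powers of ten past 1000. default_interval is a hypothetical name, not a tvboptim function.

```python
def default_interval(num_iterations):
    """Hypothetical reconstruction of the print/save interval rule.

    Assumes inclusive upper bounds (1-10 -> 1, 11-100 -> 10, 101-1000 -> 100)
    and, as an extrapolation, keeps growing by powers of ten beyond 1000.
    """
    if num_iterations < 1:
        raise ValueError("num_iterations must be at least 1")
    interval = 1
    while num_iterations > 10 * interval:
        interval *= 10
    return interval


for n in (5, 10, 50, 100, 500, 1000):
    print(n, "->", default_interval(n))
```

Under these assumptions, the max_steps: 100 setting from the Automatic Differentiation example would print and save every 10 iterations.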

Multi-Stage Optimization

Chain multiple optimization stages with different objectives:

optimization:
  stage1:
    parameters: [dynamics.G]
    loss_function: loss_fc
    max_steps: 100
  stage2:
    parameters: [dynamics.J_i]
    loss_function: loss_bold
    max_steps: 50
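The page does not spell out how stages hand off to one another. One natural reading of "chain" is that stages run in the order listed and each starts from the parameters the previous stage produced. The sketch below illustrates that assumption using plain gradient descent (not Adam) on toy objectives standing in for loss_fc and loss_bold; run_stages, the toy gradients, and the learning rates are all hypothetical.

```python
def run_stages(params, stages):
    """Run optimization stages in order (hypothetical sketch, not tvboptim).

    Each stage updates only its own parameters with plain gradient descent
    and starts from the parameters left by the previous stage.
    """
    params = dict(params)
    stage_states = {}
    for name, stage in stages.items():  # dicts preserve insertion order
        for _ in range(stage["max_steps"]):
            grads = stage["grad_fn"](params)
            for key in stage["parameters"]:
                params[key] -= stage["learning_rate"] * grads[key]
        stage_states[name] = dict(params)  # snapshot after this stage
    return params, stage_states


# Toy objectives: the loss_fc stand-in depends only on G,
# the loss_bold stand-in only on J_i.
stages = {
    "stage1": {
        "parameters": ["dynamics.G"],
        "grad_fn": lambda p: {"dynamics.G": 2.0 * (p["dynamics.G"] - 0.5)},
        "learning_rate": 0.1,
        "max_steps": 100,
    },
    "stage2": {
        "parameters": ["dynamics.J_i"],
        "grad_fn": lambda p: {"dynamics.J_i": 2.0 * (p["dynamics.J_i"] - 1.5)},
        "learning_rate": 0.1,
        "max_steps": 50,
    },
}

final, per_stage = run_stages({"dynamics.G": 0.2, "dynamics.J_i": 1.0}, stages)
print(per_stage["stage1"])
print(final)
```

Restricting each stage to its own parameter list leaves the other parameters frozen, so stage2 refines J_i around the G value that stage1 settled on.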

State Trajectory

Track parameter evolution during optimization:

# results = exp.run(), as in Quick Start
trajectory = results.optimization.global_optimization.state_trajectory
for state in trajectory:
    print(f"G={state.dynamics.G}")
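A state trajectory can be pictured as a list of parameter snapshots taken during optimization. The sketch below records one snapshot at the start and then one every save_every steps of plain gradient descent on a toy loss, storing each snapshot as a nested SimpleNamespace so the same state.dynamics.G access works. How often tvboptim records states, and whether it includes the initial one, are assumptions here; optimize_with_trajectory is a hypothetical helper.

```python
from types import SimpleNamespace


def optimize_with_trajectory(g0, grad_fn, learning_rate=0.1, max_steps=100, save_every=10):
    """Plain gradient descent on G that snapshots parameters every `save_every` steps.

    Hypothetical helper for illustration; not part of tvboptim.
    """
    g = g0
    # Assumption: the initial state is included as the first snapshot.
    trajectory = [SimpleNamespace(dynamics=SimpleNamespace(G=g))]
    for step in range(1, max_steps + 1):
        g -= learning_rate * grad_fn(g)
        if step % save_every == 0:
            trajectory.append(SimpleNamespace(dynamics=SimpleNamespace(G=g)))
    return trajectory


# Toy gradient of (G - 0.5)**2, standing in for the real loss.
trajectory = optimize_with_trajectory(0.2, lambda g: 2.0 * (g - 0.5))
for state in trajectory:
    print(f"G={state.dynamics.G:.4f}")
```

Saving only every few steps keeps the trajectory small for long runs while still showing how each parameter converges.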

See Also