from tvbo import Dynamics, SimulationExperiment
model = Dynamics.from_db("Generic2dOscillator")
exp = SimulationExperiment(dynamics=model)
Code Generation & Export
Generate standalone simulation code in multiple languages and export experiments to standard formats.
Code Generation
View the generated source code for any backend:
JAX
print(exp.render_code('jax'))
import jax
from tvbo.data.types import TimeSeries
from tvbo.utils import Bunch
import jax.numpy as jnp
def cfun(weights, history, current_state, p, delay_indices, t):
    """Linear network coupling ``b + a * sum_j w_ij * V_j`` for one mode.

    In this file the function is mapped with ``jax.vmap`` over the trailing
    axis of ``current_state``, so here ``current_state`` is indexed as
    ``[state_variable, node]``. State variable 0 is the coupled variable.

    Parameters
    ----------
    weights : array, (n_node, n_node)
        Connectivity matrix. The weighted sum runs over its last axis.
    history : unused
        Accepted for interface compatibility; no delayed state is read.
    current_state : array
        State of one mode; row 0 is gathered as the coupling source.
    p : object
        Coupling parameters exposing ``a`` (scale) and ``b`` (offset).
    delay_indices : tuple
        Only element 1 (a source-node index matrix) is used. Assumes
        ``delay_indices[1][i, j] == j`` as built by ``kernel`` — TODO confirm
        for other callers.
    t : unused

    Returns
    -------
    array, (1, n_node)
        Coupling input for each target node.
    """
    n_node = weights.shape[0]
    source_index = delay_indices[1]

    # gathered[0][i, j] holds variable 0 of the node selected by source_index[i, j]
    gathered = jnp.array([current_state[0, source_index]])
    per_cvar = gathered.reshape(-1, n_node, n_node)

    weighted_sum = jax.vmap(lambda block: jnp.sum(weights * block, axis=-1), in_axes=0)(per_cvar)
    return p.b + p.a * weighted_sum
import jax.numpy as jnp
import jax.scipy as jsp
def dfun(current_state, t, cX, _p):
    """Time derivatives of the Generic 2D oscillator (V, W).

    Parameters
    ----------
    current_state : array
        Index 0 is V (fast variable), index 1 is W (slow variable).
    t : unused
        Kept for interface compatibility with the integrator.
    cX : array
        Coupling input; ``cX[0]`` is the global coupling term.
    _p : object
        Model parameters exposing I, a, alpha, b, beta, c, d, e, f, g,
        gamma and tau as attributes.

    Returns
    -------
    array
        ``jnp.array([dV/dt, dW/dt])``.
    """
    v = current_state[0]
    w = current_state[1]
    coupled_input = cX[0]
    local_coupling = 0  # no local (intra-region) coupling in this export

    fast_rate = _p.d * _p.tau
    fast_drive = (
        _p.I * _p.gamma
        + v * _p.g
        + v * local_coupling
        + w * _p.alpha
        + coupled_input * _p.gamma
        + _p.e * v**2
        - _p.f * v**3
    )
    dv = fast_rate * fast_drive
    dw = _p.d * (_p.a + v * _p.b + _p.c * v**2 - w * _p.beta) / _p.tau
    return jnp.array([dv, dw])
def integrate(state, weights, dt, params_integrate, delay_indices, external_input):
    """
    Heun Integration
    ================

    Advance the network by one stochastic Heun (predictor-corrector) step.
    Written as the body function of ``jax.lax.scan``: it returns
    ``(carry, output)``.

    Parameters
    ----------
    state : tuple
        Scan carry ``(history, current_state)``. ``history`` is passed
        through unchanged.
    weights : array
        Connectivity matrix, forwarded to ``cfun``.
    dt : float
        Integration step size.
    params_integrate : tuple
        ``(params_dfun, params_cfun, params_stimulus)``. ``params_stimulus``
        is not applied by this export; the stimulus term is fixed at zero.
    delay_indices : tuple
        Forwarded to ``cfun``.
    external_input : tuple
        Per-step scanned input ``(t, noise)``. ``noise`` is the additive
        noise increment for this step. In ``kernel`` it is a slice of the
        output of ``g`` and is therefore already scaled by
        ``sqrt(dt) * sigma``. It must broadcast against ``current_state``.

    Returns
    -------
    tuple
        ``((history, next_state), next_state)``.
    """
    t, noise = external_input
    params_dfun, params_cfun, params_stimulus = params_integrate
    history, current_state = state
    stimulus = 0
    # Coupling is evaluated once per step from the current state (one cfun
    # call per mode, vmapped over the trailing axis) and reused in both stages.
    cX = jax.vmap(cfun, in_axes=(None, -1, -1, None, None, None), out_axes=-1)(
        weights, history, current_state, params_cfun, delay_indices, t
    )
    dX0 = dfun(current_state, t, cX, params_dfun)
    X = current_state
    # Predictor (Euler) step, including additive noise and stimulus.
    X1 = X + dX0 * dt + noise + stimulus * dt
    # Derivative at the predicted state.
    dX1 = dfun(X1, t, cX, params_dfun)
    # Corrector: trapezoidal average of both slopes. Stochastic Heun applies
    # the same noise/stimulus increment in the final update as in the
    # predictor, so it is added here as well.
    dX = (dX0 + dX1) * (dt / 2)
    next_state = current_state + dX + noise + stimulus * dt
    return (history, next_state), next_state
import jax
import jax.numpy as jnp
def g(dt, nt, n_svar, n_nodes, n_modes, seed=0, sigma_vec=None, sigma=0.0, state=None):
    """Sample additive Gaussian white-noise increments for a whole run.

    The result has shape ``(nt, n_svar, n_nodes, n_modes)`` and equals
    ``sqrt(dt) * sigma * xi`` with ``xi ~ N(0, 1)`` drawn from a PRNG key
    built from ``seed``.

    Parameters
    ----------
    dt : float
        Step size; increments scale with ``sqrt(dt)``.
    nt, n_svar, n_nodes, n_modes : int
        Output dimensions (time, state variable, node, mode).
    seed : int
        Converted with ``int()`` and used to build the PRNG key.
    sigma_vec : array-like, optional
        Per-state-variable amplitudes (length ``n_svar``). Takes precedence
        over ``sigma``.
    sigma : float
        Scalar amplitude used when ``sigma_vec`` is None.
    state : unused
        Placeholder for future state-dependent (correlated) noise.
    """
    out_shape = (nt, n_svar, n_nodes, n_modes)
    xi = jax.random.normal(jax.random.PRNGKey(int(seed)), out_shape)
    if sigma_vec is None:
        amplitude = jnp.asarray(sigma)
    else:
        # Insert singleton time/node/mode axes so sigma varies only along state variables.
        amplitude = jnp.asarray(sigma_vec)[None, ..., None, None]
    return jnp.sqrt(dt) * amplitude * xi
def monitor_raw(time_steps, trace, params, t_offset=0):
    """Wrap a raw state trace in a ``TimeSeries`` titled "Raw".

    Sample times are ``(time_steps + t_offset) * step`` with a hard-coded
    step size. ``params`` is accepted for interface compatibility and is
    not read.
    """
    step = 0.01220703125  # hard-coded sampling step
    sample_times = (time_steps + t_offset) * step
    return TimeSeries(time=sample_times, data=trace, title="Raw")
def transform_parameters(_p):
    """Return the dynamics parameters object unchanged.

    Every model parameter is read once, so an object missing one of them
    raises ``AttributeError`` here. No derived parameters are computed for
    this model.
    """
    for _name in ("I", "a", "alpha", "b", "beta", "c", "d", "e", "f", "g", "gamma", "tau"):
        getattr(_p, _name)
    return _p
# Indices of the coupled state variables (only V, index 0).
c_vars = jnp.asarray([0], dtype=jnp.int32)
def kernel(state):
    """Run the full JAX simulation described by ``state`` and return a TimeSeries.

    Builds the initial carry and connectivity/delay index arrays, samples the
    noise for every step with ``g``, then iterates ``integrate`` with
    ``jax.lax.scan`` for ``state.nt`` steps.

    Parameters
    ----------
    state : object
        Experiment state. Attributes read here: ``initial_conditions.data``,
        ``network.weights_matrix``, ``network.lengths_matrix``,
        ``network.conduction_speed.value``, ``dt``, ``nt``, ``noise``,
        ``parameters.dynamics``, ``parameters.coupling`` and ``stimulus``.

    Returns
    -------
    TimeSeries
        Raw trace labelled (Time, State Variable, Space, Mode), sampled every
        ``dt``.
    """
    # problem dimensions (fixed when this code was rendered)
    n_nodes = 1
    n_svar = 2
    n_cvar = 1  # NOTE(review): not used below
    n_modes = 1
    nh = 1  # NOTE(review): not used below
    # Start from the last sample of the initial conditions; no history buffer is kept (history is None).
    current_state, history = (state.initial_conditions.data[-1], None) ## history = current_state
    ics = (history, current_state)
    weights = state.network.weights_matrix
    # dn[i, j] == j: the source-node index for every (row i, column j) pair.
    dn = jnp.arange(int(n_nodes)) * jnp.ones((int(n_nodes), int(n_nodes))).astype(jnp.int32)
    # Integer delays in steps: length / speed / dt, rounded; all zero when the conduction speed is not positive.
    idelays = jnp.round(state.network.lengths_matrix / state.network.conduction_speed.value / state.dt).astype(jnp.int32) if state.network.conduction_speed.value > 0 else jnp.zeros((int(n_nodes), int(n_nodes)), dtype=jnp.int32)
    # Negative offsets, presumably meant to index a history buffer from its end — TODO confirm.
    # cfun in this file only reads delay_indices[1].
    di = -1 * idelays - 1
    delay_indices = (di, dn)
    dt = state.dt
    nt = state.nt
    time_steps = jnp.arange(0, nt)
    # Generate batch noise using xi with per-state sigma_vec.
    # Prefer state-provided sigma_vec (supports vmapped sweeps); fallback to experiment-level constants.
    # NOTE(review): the hasattr guard is redundant with getattr's default of 0.
    seed = getattr(state.noise, 'seed', 0) if hasattr(state.noise, 'seed') else 0
    try:
        sigma_vec_runtime = getattr(state.noise, 'sigma_vec', None)
    except Exception:
        # Besides a missing attribute (already covered by the default), this also
        # swallows any other error raised while reading sigma_vec.
        sigma_vec_runtime = None
    # Zero sigma by default, so the sampled noise is all zeros unless state.noise supplies sigma_vec.
    sigma_vec = sigma_vec_runtime if sigma_vec_runtime is not None else jnp.array([0.,0.])
    # Shape (nt, n_svar, n_nodes, n_modes), per g's docstring; one slice is fed to integrate per step.
    noise = g(dt, nt, n_svar, n_nodes, n_modes, seed=seed, sigma_vec=sigma_vec)
    p = transform_parameters(state.parameters.dynamics)
    params_integrate = (p, state.parameters.coupling, state.stimulus)
    # Scan body: carry is (history, state); external_input is (t, noise slice) for the current step.
    op = lambda ics, external_input: integrate(ics, weights, dt, params_integrate, delay_indices, external_input)
    # res stacks the per-step output of integrate (the next state) along a new leading time axis.
    latest_carry, res = jax.lax.scan(op, ics, (time_steps, noise))
    trace = res
    # Re-assemble the state variables along axis 1 in (V, W) order; with n_svar == 2 this
    # reproduces res. Assumes res is (nt, n_svar, n_nodes, n_modes) — TODO confirm.
    trace = jnp.hstack((
        trace[:, [0], :],
        trace[:, [1], :],
    ))
    t_offset = 0
    # Output k of the scan is the state after k + 1 steps, so its time is (k + 1) * dt.
    time_steps = time_steps + 1
    labels_dimensions = {
        "Time": None,
        "State Variable": ['V', 'W'],
        "Space": ['0'],
        "Mode": ['m0'],
    }
    return TimeSeries(time=(time_steps + t_offset) * dt, data=trace, title = "Raw", sample_period=dt, labels_dimensions=labels_dimensions)
Python (SciPy)
print(model.render_code('python'))
import numpy as np
import scipy.special
def Generic2dOscillator(
    current_state,
    t,
    I=0.0,
    a=-2.0,
    alpha=1.0,
    b=-10.0,
    beta=1.0,
    c=0.0,
    d=0.02,
    e=3.0,
    f=1.0,
    g=0.0,
    gamma=1.0,
    tau=1.0,
    c_glob=0.0,
    stimulus=False,
):
    """Right-hand side of the Generic 2D oscillator for SciPy ODE solvers.

    Parameters
    ----------
    current_state : sequence of float
        ``[V, W]``.
    t : float
        Current time; only used to evaluate ``stimulus``.
    I, a, alpha, b, beta, c, d, e, f, g, gamma, tau : float
        Model parameters.
    c_glob : float
        Global coupling input, entering dV/dt scaled by ``gamma`` like ``I``.
    stimulus : callable or falsy, optional
        If truthy, ``stimulus(t)`` is evaluated and added to dV/dt.
        Previously its value was computed but never used.

    Returns
    -------
    numpy.ndarray
        ``[dV/dt, dW/dt]``.
    """
    stim_t = stimulus(t) if stimulus else 0.0
    local_coupling = 0.0
    # State variables
    V = current_state[0]
    W = current_state[1]
    # Derivatives
    # The stimulus is added to V only and outside the d*tau scaling. This
    # mirrors TVB, where the stimulated variable of this model is V and the
    # integrator adds stimulus*dt to the state — confirm against the model
    # definition.
    dV_dt = (
        d
        * tau
        * (
            I * gamma
            + V * g
            + V * local_coupling
            + W * alpha
            + c_glob * gamma
            + e * V**2
            - f * V**3
        )
        + stim_t
    )
    dW_dt = d * (a + V * b + c * V**2 - W * beta) / tau
    derivatives = np.array([dV_dt, dW_dt])
    return derivatives
Julia
print(model.render_code('julia'))
# In-place right-hand side of the Generic 2D oscillator.
# Writes dV/dt into dx[1] and dW/dt into dx[2], then returns dx.
# `p` must provide the named fields destructured below; `t` is unused.
function Generic2dOscillator!(dx, x, p, t = 0)
    (; I, a, alpha, b, beta, c, d, e, f, g, gamma, tau, c_glob, local_coupling) = p
    V, W = x
    # Cubic-nullcline drive of the fast variable, including both coupling terms
    fast_drive = I .* gamma + V .* g + V .* local_coupling + W .* alpha + c_glob .* gamma + e .* V .^ 2 - f .* V .^ 3
    dx[1] = d .* tau .* fast_drive
    dx[2] = d .* (a + V .* b + c .* V .^ 2 - W .* beta) ./ tau
    return dx
end
using OrdinaryDiffEqTsit5

# Parameter values: a NamedTuple whose field names match the destructuring in Generic2dOscillator!
p = (I = 0.0, a = -2.0, alpha = 1.0, b = -10.0, beta = 1.0, c = 0.0, d = 0.02, e = 3.0, f = 1.0, g = 0.0, gamma = 1.0, tau = 1.0, c_glob = 0.0, local_coupling = 0.0)

# Initial conditions, one entry per state variable in (V, W) order
u0 = [0.1, 0.1]

# Integration window (adjust as needed)
tspan = (0.0, 1000)

prob = ODEProblem(Generic2dOscillator!, u0, tspan, p)

# Solve with Tsit5, saving the solution every 0.01 time units
sol = solve(prob, Tsit5(); saveat = 0.01)
All Code Generation Targets
From Dynamics.render_code(format)
| Format | Output |
|---|---|
| 'tvb' | TVB model class |
| 'python' / 'scipy' | SciPy ODE function |
| 'autodiff' / 'jax' | JAX dfun kernel |
| 'julia' | DifferentialEquations.jl |
| 'bifurcation-julia' | BifurcationKit.jl |
| 'bifurcation-numcont' | MatCont-style |
| 'bifurcation-auto7p' | AUTO-07p |
| 'pde-fem' | scikit-fem PDE solver |
From SimulationExperiment.render_code(format)
| Format | Output |
|---|---|
| 'jax' / 'autodiff' | Full JAX simulation |
| 'tvboptim' | tvboptim experiment |
| 'tvb' | TVB model class |
| 'networkdynamics' | NetworkDynamics.jl |
| 'mtk' | ModelingToolkit.jl |
| 'rateml' | RateML Python + Numba |
| 'cuda' | CUDA kernel |
| 'pde' | scikit-fem PDE |
Export Formats
YAML
yaml_str = exp.to_yaml()
print(yaml_str[:300])
id: 1
model: Generic2dOscillator
dynamics:
name: Generic2dOscillator
iri: tvbo:Generic2dOscillator
parameters:
I:
name: I
definition: Baseline shift of the cubic nullcline
value: 0.0
domain:
lo: '-5.0'
hi: '5.0'
step: '0.01'
log_scale
BIDS
Export as a BEP034 computational modeling derivative:
exp.to_bids("output/derivatives/tvbo")
OpenMINDS
exp.to_openminds("output/experiment.jsonld")
Model Reports
Generate formatted documentation of a model:
from IPython.display import Markdown
report = model.generate_report(format="markdown")
Markdown(report)
Generic2dOscillator
The Generic 2-Dimensional Oscillator (G2D) is a phenomenological, coupled, nonlinear, two-dimensional (i.e., two state variables, ‘V’ and ‘W’) oscillatory neural mass model. The G2D is a generalization of the well-known FitzHugh-Nagumo model (FitzHugh, 1961; Nagumo et al., 1962), adapted here to reproduce a wider class of dynamical configurations of physiological phenomena, as observed in neuronal populations using the phase-portrait method.
State Equations
\[ \dot{V} = d*\tau*\left(I*\gamma + V*g + V*c_{local} + W*\alpha + c_{glob}*\gamma + e*V^{2} - f*V^{3}\right) \] \[ \dot{W} = \frac{d*\left(a + V*b + c*V^{2} - W*\beta\right)}{\tau} \]
Parameters
| Parameter | Value | Unit | Description |
|---|---|---|---|
| \(I\) | 0.0 | N/A | Baseline shift of the cubic nullcline |
| \(a\) | -2.0 | N/A | Vertical shift of the configurable nullcline |
| \(\alpha\) | 1.0 | N/A | Constant parameter to scale the rate of feedback from the slow variable to the fast variable. |
| \(b\) | -10.0 | N/A | Linear slope of the configurable nullcline |
| \(\beta\) | 1.0 | N/A | Constant parameter to scale the rate of feedback from the slow variable to itself |
| \(c\) | 0.0 | N/A | Parabolic term of the configurable nullcline |
| \(d\) | 0.02 | N/A | Temporal scale factor |
| \(e\) | 3.0 | N/A | Coefficient of the quadratic term of the cubic nullcline |
| \(f\) | 1.0 | N/A | Coefficient of the cubic term of the cubic nullcline |
| \(g\) | 0.0 | N/A | Coefficient of the linear term of the cubic nullcline |
| \(\gamma\) | 1.0 | N/A | Constant parameter to reproduce FHN dynamics where excitatory input currents are negative |
| \(\tau\) | 1.0 | N/A | A time-scale hierarchy can be introduced for the state variables \(V\) and \(W\) |
References
Nagumo, J., Arimoto, S., & Yoshizawa, S. (1962). An active pulse transmission line simulating nerve axon. Proceedings of the IRE, 50(10), 2061-2070.
FitzHugh, R. (1961). Impulses and physiological states in theoretical models of nerve membrane. Biophysical Journal, 1(6), 445-466.
See Also
- Simulation Experiments — running experiments
- BIDS Export — detailed BIDS workflow
- OpenMINDS — OpenMINDS JSON-LD export