Code Generation & Export

Generate standalone simulation code in multiple languages and export experiments to standard formats.

Code Generation

View the generated source code for any backend:

# Look up the Generic2dOscillator model by name and wrap it in a simulation
# experiment; `model` and `exp` are reused by the examples on this page.
from tvbo import Dynamics, SimulationExperiment

model = Dynamics.from_db("Generic2dOscillator")
exp = SimulationExperiment(dynamics=model)

JAX

# Print the full JAX simulation generated for `exp` (coupling, dfun, Heun step, noise, kernel).
print(exp.render_code('jax'))


import jax
from tvbo.data.types import TimeSeries
from tvbo.utils import Bunch
import jax.numpy as jnp



def cfun(weights, history, current_state, p, delay_indices, t):
    """Linear coupling: ``b + a * sum_j weights[i, j] * x[i, j]`` per node.

    Args:
        weights: square (n_node, n_node) connectivity matrix; the weighted
            sum runs over its last axis, so row i collects node i's inputs.
        history: not read by this function.
        current_state: state array; ``current_state[0]`` (V) is gathered
            with the node-index grid ``delay_indices[1]`` — presumably laid
            out as ``[svar, node]``; confirm against the caller.
        p: coupling parameters exposing ``a`` (scale) and ``b`` (offset).
        delay_indices: ``(di, dn)`` pair. Only ``dn`` is read, so delay
            offsets ``di`` are not applied here.
        t: not read by this function.

    Returns:
        Array of shape (k, n_node) (k == 1 for a single coupling variable)
        holding ``b + a * gx``.
    """
    n_node = weights.shape[0]
    b, a = p.b, p.a

    # Gather presynaptic values of state variable 0 for every (i, j) pair and
    # add a leading coupling-variable axis.
    x_j = jnp.array([
        current_state[0, delay_indices[1]],
    ])

    pre = x_j.reshape(-1, n_node, n_node)

    def op(x):
        # Weighted sum over presynaptic nodes (last axis).
        return jnp.sum(weights * x, axis=-1)

    gx = jax.vmap(op, in_axes=0)(pre)
    return b + a*gx





import jax.numpy as jnp
import jax.scipy as jsp



def dfun(current_state, t, cX, _p):
    """Right-hand side of the Generic 2D oscillator.

    Args:
        current_state: array whose first axis indexes the state variables
            (0 -> V, 1 -> W).
        t: current time; not read here.
        cX: coupling input; ``cX[0]`` is the global coupling term.
        _p: parameter container exposing I, a, alpha, b, beta, c, d, e, f,
            g, gamma and tau as attributes.

    Returns:
        ``jnp.array([dV_dt, dW_dt])``.
    """
    I, a, alpha, b, beta, c = _p.I, _p.a, _p.alpha, _p.b, _p.beta, _p.c
    d, e, f, g, gamma, tau = _p.d, _p.e, _p.f, _p.g, _p.gamma, _p.tau

    # No local (short-range) coupling in this model variant.
    c_glob, local_coupling = cX[0], 0

    V, W = current_state[0], current_state[1]

    # Fast variable: cubic nullcline plus feedback from W and external drive.
    fast_drive = I*gamma + V*g + V*local_coupling + W*alpha + c_glob*gamma + e*V**2 - f*V**3
    dV_dt = d*tau*fast_drive

    # Slow variable: configurable (up to quadratic) nullcline.
    slow_drive = a + V*b + c*V**2 - W*beta
    dW_dt = d*slow_drive/tau

    return jnp.array([dV_dt, dW_dt])







def integrate(state, weights, dt, params_integrate, delay_indices, external_input):
    """
    Heun Integration
    ================

    Advance the network by one explicit Heun (predictor-corrector) step.

    ``state`` is the scan carry ``(history, current_state)``. ``external_input``
    is the per-step scan slice ``(t, noise_sample)``. Returns
    ``((history, next_state), next_state)`` so ``jax.lax.scan`` both carries
    and records the new state.

    NOTE(review): the noise sample in ``external_input`` is discarded and the
    stimulus term is fixed at zero. The step is therefore deterministic
    whatever the caller supplies in those slots — confirm this is intended.
    """
    t, _ignored_noise = external_input
    noise = 0

    params_dfun, params_cfun, _params_stimulus = params_integrate

    history, x_now = state
    stimulus = 0

    # Coupling is evaluated once per step and mapped over the last axis of
    # history/current_state (presumably the mode axis). The same coupling is
    # reused for the predictor and corrector derivatives.
    batched_cfun = jax.vmap(
        cfun,
        in_axes=(None, -1, -1, None, None, None),
        out_axes=-1,
    )
    cX = batched_cfun(weights, history, x_now, params_cfun, delay_indices, t)

    slope_start = dfun(x_now, t, cX, params_dfun)

    # Euler predictor.
    x_pred = x_now + slope_start * dt + noise + stimulus * dt
    slope_end = dfun(x_pred, t, cX, params_dfun)

    # Trapezoidal corrector.
    increment = (slope_start + slope_end) * (dt / 2)
    next_state = x_now + (increment)

    return (history, next_state), next_state



import jax
import jax.numpy as jnp

def g(dt, nt, n_svar, n_nodes, n_modes, seed=0, sigma_vec=None, sigma=0.0, state=None):
    """Draw a batch of scaled Gaussian white-noise increments.

    Each entry is ``sqrt(dt) * sigma * xi`` with ``xi ~ N(0, 1)``. The draw
    uses a PRNG key built from ``int(seed)``.

    Args:
        dt: time step; noise scales with ``sqrt(dt)``.
        nt, n_svar, n_nodes, n_modes: output shape components.
        seed: integer seed for ``jax.random.PRNGKey``.
        sigma_vec: optional per-state-variable amplitudes (length n_svar),
            broadcast over time, nodes and modes.
        sigma: scalar amplitude used only when ``sigma_vec`` is None.
        state: not read; reserved for state-dependent noise.

    Returns:
        Array of shape ``(nt, n_svar, n_nodes, n_modes)``.
    """
    shape = (nt, n_svar, n_nodes, n_modes)
    xi = jax.random.normal(jax.random.PRNGKey(int(seed)), shape)

    if sigma_vec is None:
        scale = jnp.asarray(sigma)
    else:
        # (n_svar,) -> (1, n_svar, 1, 1) so it broadcasts against xi.
        scale = jnp.asarray(sigma_vec)[None, ..., None, None]

    return jnp.sqrt(dt) * scale * xi






def monitor_raw(time_steps, trace, params, t_offset=0, dt=0.01220703125):
    """Wrap a raw state trace in a ``TimeSeries`` titled "Raw".

    Args:
        time_steps: integer step indices of the samples.
        trace: sampled states, stored unchanged as ``data``.
        params: not read by this function; kept for interface compatibility.
        t_offset: step offset added to ``time_steps`` before scaling.
        dt: step size used to convert steps to time. The default is the value
            previously hard-coded here. Pass the experiment's actual ``dt``
            when it differs, otherwise the time axis would be wrong.

    Returns:
        ``TimeSeries`` with ``time = (time_steps + t_offset) * dt``.
    """
    return TimeSeries(time=(time_steps + t_offset) * dt, data=trace, title="Raw")


def transform_parameters(_p):
    """Identity transform for the dynamics parameters.

    Each model parameter attribute is read once, in the model's parameter
    order. A container missing one of them therefore fails here with
    ``AttributeError``. The container itself is returned unchanged.
    """
    for name in ("I", "a", "alpha", "b", "beta", "c", "d", "e", "f", "g", "gamma", "tau"):
        getattr(_p, name)
    return _p

# Indices of the state variables that enter the coupling (0 -> V).
# NOTE(review): not referenced in the code shown here; cfun hard-codes
# current_state[0] instead — confirm whether external callers use it.
c_vars = jnp.array([0]).astype(jnp.int32)

def kernel(state):
    """Run the generated simulation for one experiment state.

    Reads from ``state``: ``initial_conditions.data``, ``network.weights_matrix``,
    ``network.lengths_matrix``, ``network.conduction_speed.value``, ``dt``,
    ``nt``, ``noise`` (optional ``seed`` / ``sigma_vec``), ``parameters.dynamics``,
    ``parameters.coupling`` and ``stimulus``.

    The time loop is a ``jax.lax.scan`` of ``integrate``; every step's new
    state is stacked into the trace. The returned ``TimeSeries`` is stamped
    ``time = (step + 1) * dt`` for steps ``0 .. nt-1`` and carries dimension
    labels for time, state variable, space and mode.

    NOTE(review): ``n_nodes`` is fixed to 1 here while ``weights`` comes from
    ``state.network``. ``dn``, ``idelays`` and the noise shape all assume a
    single node — confirm the network size matches.
    """
    # problem dimensions
    n_nodes = 1
    n_svar = 2
    n_cvar = 1  # not referenced below
    n_modes = 1
    nh = 1  # not referenced below

    # Start from the last initial-condition sample. No history buffer is kept
    # (history is None), so coupling cannot look back in time.
    current_state, history = (state.initial_conditions.data[-1], None)

    # Scan carry in the (history, current_state) layout that integrate unpacks.
    ics = (history, current_state)
    weights = state.network.weights_matrix

    # Node-index grid with dn[i, j] == j; cfun uses it to gather presynaptic V.
    dn = jnp.arange(int(n_nodes)) * jnp.ones((int(n_nodes), int(n_nodes))).astype(jnp.int32)
    # Integer delays in steps (length / speed / dt), or all zeros if speed <= 0.
    # NOTE(review): the Python-level `if` needs a concrete speed value, so
    # this presumably cannot run with conduction_speed traced under jit.
    idelays = jnp.round(state.network.lengths_matrix / state.network.conduction_speed.value / state.dt).astype(jnp.int32) if state.network.conduction_speed.value > 0 else jnp.zeros((int(n_nodes), int(n_nodes)), dtype=jnp.int32)
    # Negative history offsets. NOTE(review): cfun only reads delay_indices[1]
    # (dn), so these delays are currently not applied.
    di = -1 * idelays - 1
    delay_indices = (di, dn)

    dt = state.dt
    nt = state.nt
    time_steps = jnp.arange(0, nt)

    # Generate batch noise using xi with per-state sigma_vec.
    # Prefer state-provided sigma_vec (supports vmapped sweeps); fallback to experiment-level constants.
    # Equivalent to getattr(state.noise, 'seed', 0); the hasattr guard is redundant.
    seed = getattr(state.noise, 'seed', 0) if hasattr(state.noise, 'seed') else 0
    try:
        sigma_vec_runtime = getattr(state.noise, 'sigma_vec', None)
    except Exception:
        # Best effort: any error while reading sigma_vec falls back to the default.
        sigma_vec_runtime = None
    # Default amplitude is zero for both V and W, i.e. no noise.
    sigma_vec = sigma_vec_runtime if sigma_vec_runtime is not None else jnp.array([0.,0.])
    # (nt, n_svar, n_nodes, n_modes) samples fed to the scan as per-step inputs.
    # NOTE(review): integrate unpacks external_input as `t, _` and sets
    # noise = 0, so these samples are discarded even for non-zero sigma_vec.
    noise = g(dt, nt, n_svar, n_nodes, n_modes, seed=seed, sigma_vec=sigma_vec)


    p = transform_parameters(state.parameters.dynamics)
    # Order matches integrate's unpacking: (params_dfun, params_cfun, params_stimulus).
    params_integrate = (p, state.parameters.coupling, state.stimulus)

    # Scan over (time step index, noise sample) pairs; res stacks next_state per step.
    op = lambda ics, external_input: integrate(ics, weights, dt, params_integrate, delay_indices, external_input)
    latest_carry, res = jax.lax.scan(op, ics, (time_steps, noise))

    trace = res


    



    # Select state variables 0 (V) and 1 (W) and concatenate along axis 1;
    # with n_svar == 2 this keeps every state variable.
    trace = jnp.hstack((
        trace[:, [0], :],
        trace[:, [1], :],
        ))

    t_offset = 0
    # Shift by one: sample k is the state *after* step k.
    time_steps = time_steps + 1

    
    labels_dimensions = {
        "Time": None,
        "State Variable": ['V', 'W'],
        "Space": ['0'],
        "Mode": ['m0'],
    }
    return TimeSeries(time=(time_steps + t_offset) * dt, data=trace, title = "Raw", sample_period=dt, labels_dimensions=labels_dimensions)

Python (SciPy)

# Print the SciPy-style ODE right-hand side generated for the model alone.
print(model.render_code('python'))
import numpy as np
import scipy.special


def Generic2dOscillator(
    current_state,
    t,
    I=0.0,
    a=-2.0,
    alpha=1.0,
    b=-10.0,
    beta=1.0,
    c=0.0,
    d=0.02,
    e=3.0,
    f=1.0,
    g=0.0,
    gamma=1.0,
    tau=1.0,
    c_glob=0.0,
    stimulus=False,
):
    """Right-hand side ``[dV/dt, dW/dt]`` of the Generic 2D oscillator.

    Signature follows ``scipy.integrate.odeint`` (state first, then time).
    Parameters default to the model's standard values. ``c_glob`` is the
    global coupling input.

    NOTE(review): when ``stimulus`` is truthy it is called as
    ``stimulus(t)``, but the result is never used in the derivatives.
    Confirm whether the stimulus should enter dV/dt.
    """
    _stim_value = stimulus(t) if stimulus else 0.0
    local_coupling = 0.0

    V, W = current_state[0], current_state[1]

    # Fast variable: cubic nullcline plus feedback from W and inputs.
    fast_drive = (
        I * gamma
        + V * g
        + V * local_coupling
        + W * alpha
        + c_glob * gamma
        + e * V**2
        - f * V**3
    )
    dV_dt = d * tau * fast_drive

    # Slow variable: configurable (up to quadratic) nullcline.
    dW_dt = d * (a + V * b + c * V**2 - W * beta) / tau

    return np.array([dV_dt, dW_dt])

Julia

# Print the Julia model function plus an example OrdinaryDiffEq solve script.
print(model.render_code('julia'))








"""
    Generic2dOscillator!(dx, x, p, t = 0)

In-place right-hand side of the Generic 2D oscillator.

Writes dV/dt into `dx[1]` and dW/dt into `dx[2]`, then returns `dx`.
`p` must provide the fields I, a, alpha, b, beta, c, d, e, f, g, gamma,
tau, c_glob and local_coupling. `t` is not used.
"""
function Generic2dOscillator!(dx, x, p, t = 0)
    (; I, a, alpha, b, beta, c, d, e, f, g, gamma, tau, c_glob, local_coupling) = p
    V, W = x

    # Fast variable: cubic nullcline plus feedback from W and inputs.
    fast = I * gamma + V * g + V * local_coupling + W * alpha + c_glob * gamma + e * V^2 - f * V^3
    # Slow variable: configurable (up to quadratic) nullcline.
    slow = a + V * b + c * V^2 - W * beta

    dx[1] = d * tau * fast
    dx[2] = d * slow / tau
    return dx
end


using OrdinaryDiffEqTsit5

# Parameter values (named tuple destructured inside Generic2dOscillator!)
p = (I = 0.0, a = -2.0, alpha = 1.0, b = -10.0, beta = 1.0, c = 0.0, d = 0.02,
     e = 3.0, f = 1.0, g = 0.0, gamma = 1.0, tau = 1.0,
     c_glob = 0.0, local_coupling = 0.0)

# Initial conditions (scalar state vector): [V, W]
u0 = [
        0.1, # Initial value for V
        0.1, # Initial value for W
    ]

# Time span: both ends Float64 so the tuple is homogeneous instead of
# relying on ODEProblem to promote a mixed (Float64, Int) tuple.
tspan = (0.0, 1000.0) # Adjust time span as needed

prob = ODEProblem(Generic2dOscillator!, u0, tspan, p)

# Solve with Tsit5, saving the solution every 0.01 time units
sol = solve(prob, Tsit5(); saveat=0.01)


All Code Generation Targets

From Dynamics.render_code(format)

Format Output
'tvb' TVB model class
'python' / 'scipy' SciPy ODE function
'autodiff' / 'jax' JAX dfun kernel
'julia' DifferentialEquations.jl
'bifurcation-julia' BifurcationKit.jl
'bifurcation-numcont' MatCont-style
'bifurcation-auto7p' AUTO-07p
'pde-fem' scikit-fem PDE solver

From SimulationExperiment.render_code(format)

Format Output
'jax' / 'autodiff' Full JAX simulation
'tvboptim' tvboptim experiment
'tvb' TVB model class
'networkdynamics' NetworkDynamics.jl
'mtk' ModelingToolkit.jl
'rateml' RateML Python + Numba
'cuda' CUDA kernel
'pde' scikit-fem PDE

Export Formats

YAML

# Serialize the experiment to YAML. Only the first 300 characters are printed,
# so the output below stops mid-entry.
yaml_str = exp.to_yaml()
print(yaml_str[:300])
id: 1
model: Generic2dOscillator
dynamics:
  name: Generic2dOscillator
  iri: tvbo:Generic2dOscillator
  parameters:
    I:
      name: I
      definition: Baseline shift of the cubic nullcline
      value: 0.0
      domain:
        lo: '-5.0'
        hi: '5.0'
        step: '0.01'
        log_scale

BIDS

Export as a BEP034 computational modeling derivative:

# Write the experiment as a BEP034 derivative under output/derivatives/tvbo.
exp.to_bids("output/derivatives/tvbo")

OpenMINDS

# Write OpenMINDS metadata for the experiment to a JSON-LD file.
exp.to_openminds("output/experiment.jsonld")

Model Reports

Generate formatted documentation of a model:

# Build a Markdown-formatted report for the model and display it inline
# (IPython's Markdown object renders it in a notebook).
from IPython.display import Markdown
report = model.generate_report(format="markdown")
Markdown(report)

Generic2dOscillator

The Generic 2-Dimensional Oscillator (G2D) is a phenomenological, coupled, nonlinear, two-dimensional oscillatory neural mass model (i.e., it has two state variables, ‘V’ and ‘W’). The G2D is a generalization of the well-known FitzHugh-Nagumo model (FitzHugh, 1961; Nagumo et al., 1962), adapted here to reproduce a wider class of dynamical configurations of physiological phenomena, as observed in neuronal populations using the phase-portrait method.

State Equations

\[ \dot{V} = d*\tau*\left(I*\gamma + V*g + V*c_{local} + W*\alpha + c_{glob}*\gamma + e*V^{2} - f*V^{3}\right) \] \[ \dot{W} = \frac{d*\left(a + V*b + c*V^{2} - W*\beta\right)}{\tau} \]

Parameters

Parameter Value Unit Description
\(I\) 0.0 N/A Baseline shift of the cubic nullcline
\(a\) -2.0 N/A Vertical shift of the configurable nullcline
\(\alpha\) 1.0 N/A Constant parameter to scale the rate of feedback from the slow variable to the fast variable.
\(b\) -10.0 N/A Linear slope of the configurable nullcline
\(\beta\) 1.0 N/A Constant parameter to scale the rate of feedback from the slow variable to itself
\(c\) 0.0 N/A Parabolic term of the configurable nullcline
\(d\) 0.02 N/A Temporal scale factor
\(e\) 3.0 N/A Coefficient of the quadratic term of the cubic nullcline
\(f\) 1.0 N/A Coefficient of the cubic term of the cubic nullcline
\(g\) 0.0 N/A Coefficient of the linear term of the cubic nullcline
\(\gamma\) 1.0 N/A Constant parameter to reproduce FHN dynamics where excitatory input currents are negative
\(\tau\) 1.0 N/A A time-scale hierarchy can be introduced for the state variables \(V\) and \(W\)

References

Nagumo, J., Arimoto, S., & Yoshizawa, S. (1962). An active pulse transmission line simulating nerve axon. Proceedings of the IRE, 50(10), 2061-2070.

FitzHugh, R. (1961). Impulses and physiological states in theoretical models of nerve membrane. Biophysical Journal, 1(6), 445-466.

See Also