Reduced Wong-Wang BOLD FC Optimization

Imports
# Set up environment
import os
import time
# Mock multiple host devices so JAX can parallelize (pmap) across CPU cores
cpu = True
if cpu:
    N = 8
    os.environ['XLA_FLAGS'] = f'--xla_force_host_platform_device_count={N}'

# Import all required libraries
from scipy import io
import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import copy
import optax

# Import from tvboptim
from tvboptim import jaxify
from tvboptim.types import Parameter, GridSpace
from tvboptim.types.stateutils import show_free_parameters
from tvboptim.utils import set_cache_path, cache
from tvboptim import observation as obs
from tvboptim.execution import ParallelExecution, SequentialExecution
from tvboptim.optim.optax import OptaxOptimizer
from tvboptim.optim.callbacks import MultiCallback, DefaultPrintCallback, SavingCallback

# Import from tvbo
from tvbo.export.experiment import SimulationExperiment
from tvbo.datamodel import tvbo_datamodel
from tvbo.utils import numbered_print

# Set cache path for tvboptim
set_cache_path("./example_cache_rww")

Create a TVB-O Simulation Experiment

experiment = SimulationExperiment(
    model = {
        "name": "ReducedWongWang", 
        "parameters": {
            "w": {"name": "w", "value": 0.5},
            "I_o": {"name": "I_o", "value": 0.32}, 
        }
    },
    connectivity = {
        "parcellation": {"atlas": {"name": "DesikanKilliany"}}, 
        "conduction_speed": {"name": "cs", "value": np.array([np.inf])}
    },
    coupling = {
        "name": "Linear", 
        "parameters": {"a": {"name": "a", "value": 0.5}}
    },
    integration={
        "method": "Heun",
        "step_size": 4.0,
        "noise": {"parameters": {"sigma": {'value': 0.00283}}},
        "duration": 2 * 60_000
    },
    monitors={
        "Raw": {"name": "Raw"}, 
        "Bold": {"name": "Bold", "period": 1000.0}},
)
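
This configures a two-minute Reduced Wong-Wang simulation on the Desikan-Killiany parcellation (84 regions), with instantaneous coupling (infinite conduction speed), a stochastic Heun integrator at a 4 ms step, and Raw and BOLD monitors (BOLD sampled every 1000 ms).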

Model Functions

numbered_print(experiment.render_code(format = "jax", scalar_pre = True))
001 
002 import jax.scipy.signal as sig
003 from collections import namedtuple
004 import jax
005 from tvbo.data.types import TimeSeries
006 import jax.numpy as jnp
007 import jax.scipy as jsp
008 
009 
010 def cfun(weights, history, current_state, p, delay_indices, t):
011     n_node = weights.shape[0]
012     a, b = p.a, p.b
013 
014     x_j = jnp.array([
015 
016         current_state[0],
017 
018     ])
019 
020     pre = x_j
021 
022     def op(x): return weights @ x
023     gx = jax.vmap(op, in_axes=0)(pre)
024     return b + a*gx
025 
026 
027 def dfun(current_state, cX, _p, local_coupling=0):
028     w, I_o, J_N, a, b, d, gamma, tau_s = _p.w, _p.I_o, _p.J_N, _p.a, _p.b, _p.d, _p.gamma, _p.tau_s
029     # unpack coupling terms and states as in dfun
030     c_pop0 = cX[0]
031 
032     S = current_state[0]
033 
034     # compute internal states for dfun
035     x = I_o + J_N*c_pop0 + J_N*S*local_coupling + J_N*S*w
036     H = (-b + a*x)/(1 - jnp.exp(-d*(-b + a*x)))
037 
038     return jnp.array([
039         -S/tau_s + H*gamma*(1 - S),  # S
040     ])
041 
042 
043 def integrate(state, weights, dt, params_integrate, delay_indices, external_input):
044     """
045     Heun Integration
046     ================
047     """
048     t, noise = external_input
049 
050     params_dfun, params_cfun, params_stimulus = params_integrate
051 
052     history, current_state = state
053     stimulus = 0
054 
055     inf = jnp.inf
056     min_bounds = jnp.array([[[0.0]]])
057     max_bounds = jnp.array([[[1.0]]])
058 
059     cX = jax.vmap(cfun, in_axes=(None, -1, -1, None, None, None), out_axes=-
060                   1)(weights, history, current_state, params_cfun, delay_indices, t)
061 
062     dX0 = dfun(current_state, cX, params_dfun)
063 
064     X = current_state
065 
066     # Calculate intermediate step X1
067     X1 = X + dX0 * dt + noise + stimulus * dt
068     X1 = jnp.clip(X1, min_bounds, max_bounds)
069 
070     # Calculate derivative X1
071     dX1 = dfun(X1, cX, params_dfun)
072     # Calculate the state change dX
073     dX = (dX0 + dX1) * (dt / 2)
074     next_state = current_state + (dX) + noise
075     next_state = jnp.clip(next_state, min_bounds, max_bounds)
076 
077     return (history, next_state), next_state
078 
079 
080 timeseries = namedtuple("timeseries", ["time", "trace"])
081 
082 
083 def monitor_raw_0(time_steps, trace, params, t_offset=0):
084     dt = 4.0
085     return TimeSeries(time=(time_steps + t_offset) * dt, data=trace, title="Raw")
086 
087 
088 def monitor_temporal_average_1(time_steps, trace, params, t_offset=0):
089     dt = 4.0
090     voi = jnp.array([0])
091     istep = 1
092     t_map = time_steps[::istep] - 1
093 
094     def op(ts):
095         start_indices = (ts,) + (0,) * (trace.ndim - 1)
096         slice_sizes = (istep,) + voi.shape + trace.shape[2:]
097         return jnp.mean(jax.lax.dynamic_slice(trace[:, voi, :], start_indices, slice_sizes), axis=0)
098     vmap_op = jax.vmap(op)
099     trace_out = vmap_op(t_map)
100 
101     idxs = jnp.arange(((istep - 2) // 2), time_steps.shape[0], istep)
102     return TimeSeries(time=(time_steps[idxs]) * dt, data=trace_out[0:idxs.shape[0], :, :], title="TemporalAverage")
103 
104 
105 exp, sin, sqrt = jnp.exp, jnp.sin, jnp.sqrt
106 
107 
108 def monitor_bold_1(time_steps, trace, params, t_offset=0):
109     # downsampling via temporal average / subsample
110     dt = 4.0
111     voi = jnp.array([0])
112     period = 1000.0  # sampling period of the BOLD Monitor in ms
113     istep_int = 1  # steps taken by the averaging/subsampling monitor to get an interim period of 4 ms
114     istep = 250
115     final_istep = 250  # steps to take on the downsampled signal
116 
117     res = monitor_temporal_average_1(time_steps, trace, None)
118     time_steps_i = res.time
119     trace_new = res.data
120 
121     time_steps_new = time_steps[jnp.arange(
122         istep-1, time_steps.shape[0], istep)]
123 
124     # hemodynamic response function
125     tau_s = params.tau_s
126     tau_f = params.tau_f
127     k_1 = params.k_1
128     V_0 = params.V_0
129     stock = params.stock
130 
131     trace_new = jnp.vstack([stock, trace_new])
132 
133     def op(var): return 1/3. * exp(-0.5*(var / tau_s)) * (sin(sqrt(1. /
134                                                                    tau_f - 1./(4.*tau_s**2)) * var)) / (sqrt(1./tau_f - 1./(4.*tau_s**2)))
135     stock_steps = 5000
136     stock_time_max = 20.0  # stock time has to be in seconds
137     stock_time_step = stock_time_max / stock_steps
138     stock_time = jnp.arange(0.0, stock_time_max, stock_time_step)
139     hrf = op(stock_time)
140 
141     # Convolution along time axis
142     # via fft
143     def op1(x): return sig.fftconvolve(x, hrf, mode="valid")
144 
145     def op2(x): return jax.vmap(op1, in_axes=(
146         1), out_axes=(1))(x)  # map over nodes
147     def op3(x): return jax.vmap(op2, in_axes=(1), out_axes=(1))(
148         x)  # map over state variables
149     bold = jax.vmap(op3, in_axes=(3), out_axes=(3))(
150         trace_new)  # map over modes
151 
152     bold = k_1 * V_0 * (bold - 1.0)
153 
154     bold_idx = jnp.arange(
155         final_istep-2, time_steps_i.shape[0], final_istep)[0:time_steps_new.shape[0]] + 1
156     return TimeSeries(time=(time_steps_new + t_offset) * dt, data=bold[bold_idx, :, :], title="BOLD")
157 
158 
159 def transform_parameters(_p):
160     w, I_o, J_N, a, b, d, gamma, tau_s = _p.w, _p.I_o, _p.J_N, _p.a, _p.b, _p.d, _p.gamma, _p.tau_s
161 
162     return _p
163 
164 
165 c_vars = jnp.array([0])
166 
167 
168 def kernel(state):
169     # problem dimensions
170     n_nodes = 84
171     n_svar = 1
172     n_cvar = 1
173     n_modes = 1
174     nh = 1
175 
176     # history = current_state
177     current_state, history = (state.initial_conditions.data[-1], None)
178 
179     ics = (history, current_state)
180     weights = state.connectivity.weights
181 
182     dn = jnp.arange(n_nodes) * jnp.ones((n_nodes, n_nodes)).astype(int)
183     idelays = jnp.round(state.connectivity.lengths /
184                         state.connectivity.metadata.conduction_speed.value / state.dt).astype(int)
185     di = -1 * idelays - 1
186     delay_indices = (di, dn)
187 
188     dt = state.dt
189     nt = state.nt
190     time_steps = jnp.arange(0, nt)
191 
192     key = jax.random.PRNGKey(state.noise.metadata.seed)
193     _noise = jax.random.normal(key, (nt, n_svar, n_nodes, n_modes))
194     noise = (jnp.sqrt(dt) * state.noise.sigma[None, ..., None, None]) * _noise
195 
196     p = transform_parameters(state.parameters.model)
197     params_integrate = (p, state.parameters.coupling, state.stimulus)
198 
199     def op(ics, external_input): return integrate(ics, weights,
200                                                   dt, params_integrate, delay_indices, external_input)
201 
202     latest_carry, res = jax.lax.scan(op, ics, (time_steps, noise))
203 
204     trace = res
205 
206     t_offset = 0
207     time_steps = time_steps + 1
208 
209     params_monitors = state.monitor_parameters
210     result = [monitor_raw_0(time_steps, trace, params_monitors[0], t_offset=t_offset),
211               monitor_bold_1(time_steps, trace,
212                              params_monitors[1], t_offset=t_offset),
213               ]
214 
215     return result

Run Initial Simulation - Update Initial Conditions
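
The cells below assume the experiment has been compiled into a JAX-callable kernel (model) and an editable parameter state (state). A minimal sketch of that step, assuming tvboptim's jaxify returns such a (model, state) pair (the exact call signature is an assumption and should be checked against the tvboptim documentation):

# Compile the experiment into a JAX kernel and a parameter state
# (assumed call signature for tvboptim.jaxify)
model, state = jaxify(experiment)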

# Run the model and get results
result = model(state)

# Use the first run's Raw time series as initial conditions for the second run
state.initial_conditions = result[0]
# Use the last 5000 raw steps (20 s at dt = 4 ms) as the BOLD stock, the warm-up
# buffer that monitor_bold_1 prepends to the trace before the HRF convolution
state.monitor_parameters[1]["stock"] = result[0].data[-5000:]
result2 = model(state)

Define Observations and Loss

def observation(state):
    bold = model(state)[1]
    return obs.fc(bold, skip_t = 20)

def loss(state):
    fc = observation(state)
    # return 1 - obs.fc_corr(fc, fc_target)
    return obs.rmse(fc, fc_target)
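
The loss compares the simulated FC to an empirical target fc_target, which is not constructed in this section. A minimal sketch of loading it, assuming the empirical 84x84 FC matrix is stored in a MATLAB file under a hypothetical key "FC" (file name and key are illustrative only):

# Hypothetical: load an empirical 84x84 FC matrix as the optimization target
emp = io.loadmat("empirical_fc_desikan.mat")  # assumed file name
fc_target = jnp.array(emp["FC"])              # assumed variable key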

Parameter Exploration

# Set up parameter ranges for exploration
state.parameters.model.w.free = True
state.parameters.model.w.low = 0.001
state.parameters.model.w.high = 0.7

state.parameters.coupling.a.free = True
state.parameters.coupling.a.low = 0.001
state.parameters.coupling.a.high = 0.7
show_free_parameters(state)

# Create grid for parameter exploration
# n = 32
n = 64
# _params = copy.deepcopy(state)
# _params.nt = 10_000  # 10s simulation for better frequency resolution
params_set = GridSpace(state, n=n)

@cache("explore", redo = False)
def explore():
    exec = ParallelExecution(loss, params_set, n_pmap=8)
    # Alternative: Sequential execution
    # exec = SequentialExecution(loss, params_set)
    return exec.run()

results = explore()
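
To inspect the exploration, the losses can be mapped back onto the parameter grid; a minimal sketch, assuming GridSpace samples n points per free parameter and the execution returns one loss per grid point in grid order (this layout and the axis mapping are assumptions):

# Visualize the explored loss landscape on the n x n grid
loss_grid = np.array(results).reshape(n, n)
plt.imshow(loss_grid, origin="lower", aspect="auto")
plt.xlabel("grid index (coupling a)")
plt.ylabel("grid index (w)")
plt.colorbar(label="loss")
plt.show()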

Run Optimization

# Create and run optimizer
cb = MultiCallback([
    DefaultPrintCallback(every=10),
    SavingCallback(key = "state", save_fun = lambda *args: args[1]) # save updated state
])

@cache("optimize", redo = False)
def optimize():
    opt = OptaxOptimizer(loss, optax.adam(0.01, b2 = 0.9999), callback = cb)
    fitted_state, fitting_data = opt.run(state, max_steps=400)
    return fitted_state, fitting_data

fitted_state, fitting_data = optimize()
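
After the run, the fitted global values and the resulting loss can be checked directly on the returned state (the loss call re-runs one simulation); a short sketch:

# Inspect the fitted global parameters and the final loss
print("w:", fitted_state.parameters.model.w.value)
print("a:", fitted_state.parameters.coupling.a.value)
print("loss:", loss(fitted_state))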

Refine Optimization by Setting Regional Parameters

# Copy the already optimized state and make w and I_o regional (one value per node)
_fitted_state = copy.deepcopy(fitted_state)

_fitted_state.parameters.model.w.free = True
_fitted_state.parameters.model.w.value = jnp.broadcast_to(_fitted_state.parameters.model.w.value, (84, 1))
_fitted_state.parameters.model.I_o.free = True
_fitted_state.parameters.model.I_o.value = jnp.broadcast_to(_fitted_state.parameters.model.I_o.value, (84, 1))
_fitted_state.parameters.coupling.a.free = False

@cache("optimize_het", redo = False)
def optimize_het():
    opt = OptaxOptimizer(loss, optax.adam(0.004, b2 = 0.999), callback=cb)
    fitted_state, fitting_data = opt.run(_fitted_state, max_steps=200)
    return fitted_state, fitting_data

fitted_state_het, fitting_data_het = optimize_het()
Loading optimize_het from cache, last modified 2025-06-17 14:09:26.288373
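
Finally, the regionally refined fit can be compared against the target FC using the observation helper defined above; a minimal sketch:

# Compare the simulated FC of the refined fit with the empirical target
fc_fit = observation(fitted_state_het)
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
axes[0].imshow(np.asarray(fc_fit), vmin=-1, vmax=1)
axes[0].set_title("Simulated FC (regional fit)")
axes[1].imshow(np.asarray(fc_target), vmin=-1, vmax=1)
axes[1].set_title("Empirical FC")
plt.show()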