import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
from tvbo import SimulationExperiment, Network
from tvboptim.data import load_structural_connectivity
import bsplot
from matplotlib.colors import Normalize
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pandas as pd
# Load the experiment definition from its YAML description.
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
experiment_yaml = "/Users/leonmartin_bih/tools/tvbo/database/experiments/JR_MEG_FrequencyGradient_Optimization.yaml"
exp = SimulationExperiment.from_file(experiment_yaml)
print(exp.render_code('tvboptim'))
Jansen-Rit MEG Frequency Gradient Optimization via tvboptim
Reproducing MEG resting-state frequency gradients using network dynamics. Region-specific Jansen-Rit parameters (a, b) are fitted so that each region's peak frequency matches a target gradient, ranging from 11 Hz in visual cortex down to 7 Hz in association areas.
weights, lengths, region_labels = load_structural_connectivity(name="dk_average")

# Load atlas for brain mapping: a lookup table (one row kept per
# freesurfer_idx) and the parcellation volume itself.
# NOTE(review): absolute, machine-specific paths.
dk_info = pd.read_csv(
    "/Users/leonmartin_bih/work_data/toolboxes/tvboptim/docs/data/dk_average/fs_default_freesurfer_idx.csv"
).drop_duplicates(subset="freesurfer_idx")
dk = nib.load(
    "/Users/leonmartin_bih/work_data/toolboxes/tvboptim/docs/data/dk_average/aparc+aseg-mni_09c.nii.gz"
)

# Positions of the left/right "LOG" regions, used as the reference from which
# distances are measured (presumably the visual-cortex seed, going by the
# name dist_from_vc — confirm against the atlas labels).
region_labels = np.array(region_labels)
idx_l = np.nonzero(region_labels == "L.LOG")[0]
idx_r = np.nonzero(region_labels == "R.LOG")[0]

# Per-region distance: average of the two hemispheres' rows of `lengths`.
seed_lengths = 0.5 * (lengths[idx_l, :] + lengths[idx_r, :])
dist_from_vc = np.array(np.squeeze(seed_lengths))

# Map distance linearly onto frequency: the closest region gets f_max,
# the farthest region gets f_min.
f_min, f_max = 7, 11
nearest, farthest = dist_from_vc.min(), dist_from_vc.max()
delta_f = (f_max - f_min) / (farthest - nearest)
target_peak_freqs = f_max - delta_f * (dist_from_vc - nearest)

# Run initial simulation to get frequency axis from actual PSD computation
ns = exp.execute("tvboptim")
sim_init = exp.run("tvboptim", mode="simulation")
frequencies = np.array(sim_init.integration.main.observations.simulated_psd.frequencies)


def _cauchy_pdf(freq_axis, center, width):
    """Evaluate a Cauchy (Lorentzian) density centred at `center` on `freq_axis`."""
    return 1 / (np.pi * width * (1 + ((freq_axis - center) / width) ** 2))


# Compute target PSDs as Cauchy distributions using actual frequency axis:
# one row per entry of target_peak_freqs, each peaking at that frequency.
gamma = 1.0
target_psd = np.array(
    [_cauchy_pdf(frequencies, f0, gamma) for f0 in target_peak_freqs]
)
============================================================
STEP 1: Running simulation...
============================================================
/Users/leonmartin_bih/tools/tvbo/.venv/lib/python3.13/site-packages/jax/_src/third_party/scipy/signal_helper.py:47: UserWarning: nperseg=256 is greater than input_length=100, using nperseg=100
warnings.warn(f'nperseg={nperseg_int} is greater than {input_length=},'
Simulation period: 1000.0 ms, dt: 1.0 ms
Transient period: 20000.0 ms
Simulation complete.
============================================================
Experiment complete.
============================================================
# Run the complete workflow, passing the Cauchy-shaped target PSDs as the
# fitting target. Per the captured log that follows this call, the run covers
# a simulation step, the "frequency_landscape" exploration, and an
# optimization step. The returned `results` object is what all plotting code
# further down reads from.
results = exp.run("tvboptim", target=target_psd)
============================================================
STEP 1: Running simulation...
============================================================
/Users/leonmartin_bih/tools/tvbo/.venv/lib/python3.13/site-packages/jax/_src/third_party/scipy/signal_helper.py:47: UserWarning: nperseg=256 is greater than input_length=100, using nperseg=100
warnings.warn(f'nperseg={nperseg_int} is greater than {input_length=},'
Simulation period: 1000.0 ms, dt: 1.0 ms
Transient period: 20000.0 ms
Simulation complete.
============================================================
STEP 2: Running explorations...
============================================================
> frequency_landscape
Explorations complete.
============================================================
STEP 4: Running optimization...
============================================================
Step 0: 0.563778
/Users/leonmartin_bih/tools/tvbo/.venv/lib/python3.13/site-packages/jax/_src/third_party/scipy/signal_helper.py:47: UserWarning: nperseg=256 is greater than input_length=100, using nperseg=100
warnings.warn(f'nperseg={nperseg_int} is greater than {input_length=},'
Step 10: 0.276656
Step 20: 0.163564
Step 30: 0.129762
Step 40: 0.105220
Step 50: 0.091368
Step 60: 0.085298
Step 70: 0.082105
Step 80: 0.078121
Step 90: 0.070131
Step 100: 0.068467
Step 110: 0.066541
Step 120: 0.065615
Step 130: 0.064932
Step 140: 0.064446
Step 150: 0.063644
Optimization complete.
============================================================
Experiment complete.
============================================================
/Users/leonmartin_bih/tools/tvbo/.venv/lib/python3.13/site-packages/jax/_src/third_party/scipy/signal_helper.py:47: UserWarning: nperseg=256 is greater than input_length=100, using nperseg=100
warnings.warn(f'nperseg={nperseg_int} is greater than {input_length=},'
Results
def build_freq_map(freqs):
    """Paint per-region values onto the atlas volume.

    For every label in the module-level ``region_labels`` the matching
    ``freesurfer_idx`` is looked up in ``dk_info`` (first match on the
    ``acronym`` column) and all voxels of the ``dk`` atlas carrying that
    index are set to the corresponding entry of ``freqs``. Voxels that belong
    to no listed region stay 0.

    Args:
        freqs: Indexable sequence with one value per entry of
            ``region_labels``, in the same order.

    Returns:
        Array with the same shape and dtype as ``dk.get_fdata()``.

    Raises:
        IndexError: If a region label has no row in ``dk_info``.
    """
    # Fetch the atlas once; it does not change between regions.
    atlas = dk.get_fdata()
    fmap = np.zeros_like(atlas)
    for i, name in enumerate(region_labels):
        matches = dk_info.loc[dk_info.acronym == name, "freesurfer_idx"].values
        if matches.size == 0:
            raise IndexError(
                f"Region label {name!r} not found in dk_info.acronym"
            )
        # In-place masked assignment avoids reallocating the full volume
        # for every region.
        fmap[atlas == matches[0]] = freqs[i]
    return fmap
# Two-row layout: panels A, B, C on top and D, E, F below, each spanning two
# columns of the mosaic grid.
mosaic = "\nAABBCC\nDDEEFF\n"
fig = plt.figure(figsize=(16, 9), constrained_layout=True)
axes = fig.subplot_mosaic(mosaic)
cmap = plt.cm.cividis
# Number of leading samples shown in panel A: 1000 divided by the integration
# step size (presumably 1000 ms worth of samples — confirm the unit of
# step_size).
t_max = int(1000 / exp.integration.step_size)
# A: Dynamics with horizontal split (top=transient, bottom=stabilized)
ax = axes["A"]
ax.axis("off")
ax.set_title("A. Network Dynamics", fontsize=10)

# The host axis is only a titled container; the traces live in two stacked
# insets (top: transient segment, bottom: main/stabilized segment).
ax_top = ax.inset_axes([0, 0.52, 1, 0.48])
ax_bot = ax.inset_axes([0, 0, 1, 0.48])

for inset, segment, label in (
    (ax_top, results.integration.transient, "Transient"),
    (ax_bot, results.integration.main, "Stabilized"),
):
    # First t_max samples, index 0 on the second axis (labelled y0 on the
    # plot), all columns of the last axis — one trace per column.
    data = segment.data[:t_max, 0, :]
    t_axis = segment.time[:t_max]
    # Time-averaged level of each trace drives its colour. Computed once here
    # rather than once per plotted trace.
    trace_means = data.mean(0)
    norm = Normalize(vmin=trace_means.min(), vmax=trace_means.max())
    for i in range(data.shape[1]):
        inset.plot(
            t_axis,
            data[:, i],
            color=cmap(norm(trace_means[i])),
            lw=0.5,
        )
    inset.text(
        0.95,
        0.9,
        label,
        transform=inset.transAxes,
        ha="right",
        va="top",
        bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8),
        fontsize=9,
    )

# Only the bottom inset carries x tick labels and the time axis label.
ax_top.set(ylabel="y0 [a.u.]")
ax_top.set_xticklabels([])
ax_bot.set(xlabel="Time [ms]", ylabel="y0 [a.u.]")
# B: Power Spectral Density
ax = axes["B"]
sim_psd = results.integration.main.observations.simulated_psd
psd_freqs = sim_psd.frequencies
psd_traces = sim_psd.psd.squeeze()
# Average over the first axis of the squeezed PSD array; plotted as the bold
# line and used to locate the overall peak.
psd_mean = psd_traces.mean(axis=0)

# Faint individual spectra underneath, bold mean spectrum on top.
ax.plot(psd_freqs, psd_traces.T, lw=1, color="k", alpha=0.15)
ax.plot(psd_freqs, psd_mean, lw=2, color="k")

# Frequency at which the mean spectrum is largest.
peak = psd_freqs[np.argmax(psd_mean)]
ax.text(
    0.95,
    0.95,
    f"Peak @ {peak:.1f} Hz",
    transform=ax.transAxes,
    ha="right",
    va="top",
    bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8),
)
ax.set(
    xlabel="Frequency [Hz]",
    ylabel="Power",
    xlim=(0, 20),
    yscale="log",
    title="B. Power Spectral Density",
)
# C: Parameter Landscape with Fitted Points
ax = axes["C"]
expl = results.exploration.frequency_landscape
pc = expl.grid.collect()

# Grid resolution along each explored parameter, taken from the experiment
# definition.
landscape_params = exp.explorations["frequency_landscape"].parameters
n1 = landscape_params["a"].domain.n
n2 = landscape_params["b"].domain.n

# Exploration results reshaped to (n1, n2) and transposed so that `a` runs
# along the x axis and `b` along the y axis, matching `extent` below.
a_grid = pc.dynamics.a
b_grid = pc.dynamics.b
landscape = expl.results.reshape(n1, n2).T
im = ax.imshow(
    landscape,
    cmap="cividis",
    extent=[a_grid.min(), a_grid.max(), b_grid.min(), b_grid.max()],
    origin="lower",
    aspect="auto",
)
plt.colorbar(im, ax=ax, label="Peak Frequency [Hz]", shrink=0.8)

# Overlay the fitted (a, b) pairs. The fitted parameters may be wrapper
# objects exposing `.value` or plain arrays; unwrap only when needed.
fitted_dynamics = results.optimization.spectral_gradient_fit.fitted_params.dynamics
a_fit = fitted_dynamics.a
b_fit = fitted_dynamics.b
a_vals = getattr(a_fit, "value", a_fit)
b_vals = getattr(b_fit, "value", b_fit)
ax.scatter(
    a_vals, b_vals, color="white", s=20, edgecolors="black", lw=0.5, zorder=5, alpha=0.7
)
ax.set(
    xlabel=r"a [ms$^{-1}$]",
    ylabel=r"b [ms$^{-1}$]",
    title="C. Parameter Landscape + Fitted",
)
# D / E: Target vs. simulated peak frequency, before and after optimization.
# Both panels share one layout, so they are drawn by the same loop. The
# identity line and the axis limits are derived from f_min / f_max so they
# stay in sync with the target range used to build target_peak_freqs.
scatter_limits = (f_min - 0.5, f_max + 0.5)
for panel_key, simulated_peaks, panel_title in (
    (
        "D",
        results.integration.main.observations.peak_frequencies,
        "D. Before Optimization",
    ),
    (
        "E",
        results.optimization.spectral_gradient_fit.simulation.observations.peak_frequencies,
        "E. After Optimization",
    ),
):
    ax = axes[panel_key]
    ax.scatter(
        target_peak_freqs,
        simulated_peaks,
        alpha=0.6,
        s=25,
        color="royalblue",
        edgecolors="k",
        lw=0.5,
    )
    # Dashed identity line: points on it match their target exactly.
    ax.plot([f_min, f_max], [f_min, f_max], "k--", lw=1.5)
    ax.set(
        xlabel="Target [Hz]",
        ylabel="Simulated [Hz]",
        xlim=scatter_limits,
        ylim=scatter_limits,
        aspect="equal",
        title=panel_title,
    )
# F: Brain Frequency Maps (3 horizontal insets)
ax = axes["F"]
ax.axis("off")
ax.set_title("F. Brain Frequency Maps", fontsize=10)

# Shared colour scale for all three maps, derived from the target frequency
# range (same half-hertz margin as the scatter panels) instead of hard-coded
# numbers.
brain_cmap = "cividis_r"
norm_brain = Normalize(vmin=f_min - 0.5, vmax=f_max + 0.5)

# (title, per-region peak frequencies) for each map, left to right.
brain_maps = (
    ("Initial", results.integration.main.observations.peak_frequencies),
    ("Target", target_peak_freqs),
    (
        "Fitted",
        results.optimization.spectral_gradient_fit.simulation.observations.peak_frequencies,
    ),
)
n_maps = len(brain_maps)
# Equal-width insets side by side inside the (hidden) host axis.
brain_axes = [ax.inset_axes([i / n_maps, 0.1, 1 / n_maps, 0.8]) for i in range(n_maps)]

for bax, (title, freqs) in zip(brain_axes, brain_maps):
    # Project the per-region values onto the atlas volume, reusing the
    # atlas affine so the image sits in the same space as `dk`.
    volume = nib.Nifti1Image(build_freq_map(freqs), dk.affine)
    bsplot.brain.glass_brain(
        volume,
        cmap=brain_cmap,
        view="horizontal",
        threshold=0,
        norm=norm_brain,
        ax=bax,
    )
    bax.set_title(title, fontsize=9)
    bax.axis("off")

# One colourbar for the whole panel, built from the shared norm/colormap
# rather than from any single map.
sm = plt.cm.ScalarMappable(cmap=brain_cmap, norm=norm_brain)
fig.colorbar(sm, ax=ax, orientation="vertical", shrink=0.6, label="Peak Freq [Hz]")
# Attach the overall title to `fig` explicitly rather than to pyplot's
# implicit "current figure", so it cannot land on another figure if a helper
# called earlier changed the current one.
# y=1.05 places the title slightly above the top edge of the figure.
fig.suptitle(
    "Jansen-Rit MEG Frequency Gradient Optimization",
    fontsize=14,
    fontweight="bold",
    y=1.05,
)
# Apply the bsplot house style to the finished figure.
bsplot.style.format_fig(fig)