Reduced Wong-Wang BOLD FC Optimization via tvboptim

A complete BOLD functional connectivity workflow: simulation, parameter-space exploration, and global followed by regional parameter fitting.
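In the reduced Wong-Wang model each region is described by a single synaptic gating variable S, whose dynamics depend on the local recurrent weight w, a constant external input current I_o, and the global coupling strength G; these are exactly the parameters fitted below.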

import os

# Expose 8 virtual CPU devices to JAX/XLA; this must be set before JAX is first imported.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

from tvbo import SimulationExperiment, Network

exp = SimulationExperiment.from_file(
    "/Users/leonmartin_bih/tools/tvbo/database/experiments/RWW_BOLD_FC_Optimization.yaml"
)
results = exp.run("tvboptim")
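
Running the experiment prints the progress log below and returns a results object that bundles everything the Results section unpacks: the main simulation and its BOLD observation (results.integration.main), the parameter-landscape exploration (results.exploration.parameter_landscape), and the fitted parameters, state trajectory, and re-simulated FC of each optimization stage (results.optimization.global_optimization and results.optimization.regional_optimization).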

============================================================
STEP 1: Running simulation...
============================================================
  Simulation period: 120000.0 ms, dt: 4.0 ms
  Transient period: 120000.0 ms
  Simulation complete.

============================================================
STEP 2: Running explorations...
============================================================
  > parameter_landscape
  Explorations complete.

============================================================
STEP 4: Running optimization...
============================================================
  Multi-stage optimization: 2 stages

>>> Stage 1/2: global_optimization
    Free parameters: w, G
Step 0: 0.291980
Step 10: 0.284520
Step 20: 0.253220
Step 30: 0.255362
Step 40: 0.258842
Step 50: 0.245892
Step 60: 0.245021
Step 70: 0.244523
Step 80: 0.244445
Step 90: 0.244268
Step 100: 0.244100
Step 110: 0.243942
Step 120: 0.243780
Step 130: 0.243609
Step 140: 0.243432
Step 150: 0.243248
Step 160: 0.243059
Step 170: 0.242864
Step 180: 0.242665
Step 190: 0.242463
Step 200: 0.242257
Step 210: 0.242049
Step 220: 0.241840
Step 230: 0.241629
Step 240: 0.241417
Step 250: 0.241205
Step 260: 0.240993
Step 270: 0.240782
Step 280: 0.240571
Step 290: 0.240361

>>> Stage 2/2: regional_optimization
    Free parameters: w, I_o
    Warmup from: global_optimization
Step 0: 0.240153
Step 10: 0.217843
Step 20: 0.193627
Step 30: 0.184668
Step 40: 0.179224
Step 50: 0.170323
Step 60: 0.167877
Step 70: 0.163732
Step 80: 0.161217
Step 90: 0.159087
Step 100: 0.158655
Step 110: 0.158435
Step 120: 0.161714
Step 130: 0.158633
Step 140: 0.157120
Step 150: 0.156907
Step 160: 0.165927
Step 170: 0.156724
Step 180: 0.158230
Step 190: 0.156313

============================================================
  Multi-stage optimization complete
============================================================

============================================================
Experiment complete.
============================================================
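
The regional stage warm-starts from the parameters found by the global stage, which is why it begins at essentially the loss where the global stage ended (about 0.240). A minimal, generic sketch of this staged warm-start pattern is shown below, using a plain optax gradient loop on a toy loss; the loss function, parameter names, and region count are hypothetical, and this is not how tvboptim is implemented internally.

# Illustrative only: a generic two-stage warm-start with optax on a toy loss.
import jax
import jax.numpy as jnp
import optax

n_regions = 68  # hypothetical region count


def fit(params, loss_fn, steps=300, lr=1e-2):
    """Run a simple Adam loop and return the fitted parameter pytree."""
    opt = optax.adam(lr)
    opt_state = opt.init(params)

    @jax.jit
    def step(params, opt_state):
        loss, grads = jax.value_and_grad(loss_fn)(params)
        updates, opt_state = opt.update(grads, opt_state)
        return optax.apply_updates(params, updates), opt_state, loss

    for _ in range(steps):
        params, opt_state, loss = step(params, opt_state)
    return params


# Stage 1: fit global scalars (stand-in loss; the real objective compares simulated and empirical FC).
def stage1_loss(p):
    return (p["w"] - 0.9) ** 2 + (p["G"] - 0.5) ** 2


stage1 = fit({"w": jnp.array(1.0), "G": jnp.array(0.1)}, stage1_loss)


# Stage 2: warm-start per-region parameters from the stage-1 result instead of the defaults.
def stage2_loss(p):
    return jnp.mean((p["w"] - 0.9) ** 2) + jnp.mean((p["I_o"] - 0.3) ** 2)


stage2_init = {
    "w": jnp.full((n_regions,), stage1["w"]),
    "I_o": jnp.full((n_regions,), 0.3),
}
stage2 = fit(stage2_init, stage2_loss)

The only point of the sketch is that the second stage's initial values are built from the first stage's result rather than from the experiment defaults.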

Results
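
Across the two stages the FC-fit loss falls from roughly 0.292 to 0.240 with only the global parameters (w, G) free, and to roughly 0.156 once w and I_o vary per region; the figure below compares both fits against the empirical target.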

#| fig-cap: "**RWW BOLD FC Optimization.** (A) Neural activity. (B) BOLD signal. (C) Parameter landscape with optimization trajectory. (D-F) FC matrices: global, regional, target. (G-H) FC scatter plots. (I) Fitted regional parameters."

import bsplot
from matplotlib.colors import Normalize
import matplotlib.pyplot as plt
import matplotlib.patheffects as path_effects
import jax.numpy as jnp
import numpy as np
from tvboptim.observations.observation import fc_corr

mosaic = """
AABBCC
DDEEFF
GGHHII
"""
fig, axes = plt.subplot_mosaic(mosaic, layout='tight', figsize=(8, 6))
cmap = plt.cm.cividis

fc_target = results.integration.main.observations.empirical_fc

# A: Neural Activity
ax = axes["A"]
t_max = int(1000 / exp.integration.step_size)
data = results.integration.main.data[:t_max, 0, :]
norm = Normalize(vmin=data.mean(0).min(), vmax=data.mean(0).max())
for i in range(data.shape[1]):
    ax.plot(
        results.integration.main.time[:t_max],
        data[:, i],
        color=cmap(norm(data.mean(0)[i])),
        lw=0.5,
    )
ax.set(xlabel="Time [ms]", ylabel="S [a.u.]", title="Neural Activity")

# B: BOLD Signal
ax = axes["B"]
bold = results.integration.main.observations.bold
data = bold.data[:60, 0, :]
norm = Normalize(vmin=data.mean(0).min(), vmax=data.mean(0).max())
for i in range(data.shape[1]):
    ax.plot(bold.time[:60], data[:, i], color=cmap(norm(data.mean(0)[i])), lw=0.8)
ax.set(xlabel="Time [s]", ylabel="BOLD [a.u.]", title="BOLD Signal")

# C: Parameter Landscape with Trajectory
ax = axes["C"]
expl = results.exploration.parameter_landscape
pc = expl.grid.collect()
G_vals, w_vals = pc.coupling.FastLinearCoupling.G.flatten(), pc.dynamics.w.flatten()
# Loss landscape sampled on a 32 x 32 grid over (G, w)
im = ax.imshow(
    jnp.stack(expl.results).reshape(32, 32).T,
    cmap="cividis_r",
    extent=[G_vals.min(), G_vals.max(), w_vals.min(), w_vals.max()],
    origin="lower",
    aspect="auto",
)
plt.colorbar(im, ax=ax, label="Loss", shrink=0.8)

# Initial & fitted points
for label, G, w, dy in [
    (
        "Initial",
        results.state.coupling.FastLinearCoupling.G,
        results.state.dynamics.w,
        0.03,
    ),
    (
        "Optimized",
        results.global_optimization.fitted_params.coupling.FastLinearCoupling.G.value,
        results.global_optimization.fitted_params.dynamics.w.value,
        -0.05,
    ),
]:
    ax.scatter(
        G, w, color="white", s=80, marker="o", edgecolors="k", linewidths=2, zorder=5
    )
    ax.annotate(
        label,
        (G, w),
        xytext=(G, w + dy),
        color="white",
        fontweight="bold",
        ha="center",
        path_effects=[path_effects.withStroke(linewidth=3, foreground="black")],
    )

# Trajectory
route = results.optimization.global_optimization.state_trajectory
ax.scatter(
    [s.coupling.FastLinearCoupling.G.value for s in route],
    [s.dynamics.w.value for s in route],
    color="white",
    s=10,
    marker="o",
    edgecolors="k",
    linewidths=0.5,
    zorder=4,
)
ax.set(xlabel="G", ylabel="w", title="Exploration")

# D-F: FC Matrices
for key, fc, title in [
    ("D", np.array(results.optimization.global_optimization.simulation.observations.fc), f"Global"),
    ("E", np.array(results.optimization.regional_optimization.simulation.observations.fc), f"Regional"),
    ("F", fc_target, "Target FC"),
]:
    ax = axes[key]
    fc_plot = np.copy(fc)
    np.fill_diagonal(fc_plot, np.nan)
    ax.imshow(fc_plot, cmap="cividis", vmin=0, vmax=0.9)
    ax.set(xticks=[], yticks=[], title=title)

# G-H: Scatter Plots
triu = np.triu_indices_from(fc_target, k=1)
for key, fc, title in [
    ("G", results.optimization.global_optimization.simulation.observations.fc, "Global Fit"),
    ("H", results.optimization.regional_optimization.simulation.observations.fc, "Regional Fit"),
]:
    ax = axes[key]
    ax.scatter(fc_target[triu], np.array(fc)[triu], alpha=0.3, s=8, color="royalblue")
    ax.plot([0, 1], [0, 1], "k--", lw=1.5)
    ax.set(
        xlabel="Empirical FC",
        ylabel="Simulated FC",
        title=f"r={fc_corr(fc, fc_target):.3f}",
        aspect="equal",
    )
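
# fc_corr comes from tvboptim; as a rough, hypothetical stand-in (not the
# library implementation), the FC fit can be summarized as the Pearson
# correlation of the upper-triangular FC entries:
def fc_pearson_sketch(fc_a, fc_b):
    iu = np.triu_indices_from(np.asarray(fc_a), k=1)
    return np.corrcoef(np.asarray(fc_a)[iu], np.asarray(fc_b)[iu])[0, 1]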

# I: Fitted Regional Parameters (dual y-axis)
mean_conn = exp.network.weights.mean(axis=1)
opt_g, opt_r = results.global_optimization, results.regional_optimization

ax1 = axes["I"]
ax2 = ax1.twinx()

# Left axis: w (blue)
ax1.scatter(
    mean_conn,
    opt_r.fitted_params.dynamics.w.value.flatten(),
    alpha=0.7,
    s=5,
    color="royalblue",
    edgecolors="k",
    lw=0.5,
    label="w (regional)",
)
ax1.axhline(
    opt_g.fitted_params.dynamics.w.value,
    color="royalblue",
    ls="--",
    lw=2,
    label=f"w (global): {float(opt_g.fitted_params.dynamics.w.value):.3f}",
)
ax1.set_xlabel("Mean Connectivity")
ax1.set_ylabel("w", color="royalblue")
ax1.tick_params(axis="y", labelcolor="royalblue")

# Right axis: I_o (orange)
ax2.scatter(
    mean_conn,
    opt_r.fitted_params.dynamics.I_o.value.flatten(),
    alpha=0.7,
    s=5,
    color="darkorange",
    edgecolors="k",
    lw=0.5,
    label="I_o (regional)",
)
ax2.axhline(
    opt_g.fitted_params.dynamics.I_o,
    color="darkorange",
    ls="--",
    lw=2,
    label=f"I_o (global): {float(opt_g.fitted_params.dynamics.I_o):.3f}",
)
ax2.set_ylabel("I_o", color="darkorange")
ax2.tick_params(axis="y", labelcolor="darkorange")

ax1.set_title("Regional Parameters")
# lines1, labels1 = ax1.get_legend_handles_labels()
# lines2, labels2 = ax2.get_legend_handles_labels()
# ax1.legend(lines1 + lines2, labels1 + labels2, loc="upper right", fontsize=8)

plt.suptitle(
    "Reduced Wong-Wang BOLD FC Optimization", fontsize=14, fontweight="bold", y=1
)
bsplot.style.format_fig(fig)
Figure 1