Competing Objectives in Hopf Network Fitting: FC vs. Frequency Gradient

Optimizing two contradictory losses together acts as a biological regularizer

Introduction

We fit a delay-coupled Hopf network to two targets at once: an empirical functional connectivity (FC) matrix and a spatial peak-frequency gradient. The point of this workflow is what happens when those targets disagree.

Optimized on its own, each target drives the model into a biologically implausible corner:

  • FC correlation alone is maximized by global oversynchronization: drive the coupling \(G\) up until every region phase-locks. FC correlation looks excellent, but the regional frequency structure collapses onto a single global rhythm.
  • The frequency-gradient term alone is maximized by decoupling: drive \(G \to 0\) so every node free-runs at its own natural frequency. The gradient is reproduced perfectly, but with no coupling there is no FC structure left.

The two objectives therefore pull toward opposite ends of the \(G\) axis. Optimizing them together acts as a biological regularizer: each term holds the other back from its degenerate optimum, forcing the fit into the intermediate, metastable regime in which real cortex is thought to operate. The aim is a plausible fit, not a perfect one. A single Hopf node is a coarse caricature of a brain region, so some residual mismatch is expected by construction.

We first show the conflict directly on one starting point, then build the workflow around it in two stages:

  1. Stage 0, genetic pre-search (NSGA-II) over \((G, a, \sigma)\). We do not search \(\omega\) here. The natural frequency of an isolated Hopf node is \(\omega/2\pi\) (times 1000 to get Hz, since time is in ms), so we can initialize \(\omega\) directly from the per-region target peak frequencies. The GA then only has to find global scalars that counteract the frequency shift induced by coupling, while every region already starts from the correct natural frequency. \(\omega\) becomes a free parameter again in the gradient stage.
  2. Stage 1, parallel gradient optimization from every Pareto seed. Instead of a per-region pretune step, we launch one full multimodal optimization per non-dominated solution on the Pareto front, all running in parallel via ParallelExecution. A small wrapper turns the OptaxOptimizer.run loop into a function that can be mapped with pmap or lax.map. It also builds the differentiable initial state from the scalar GA values that arrive batched through DataAxis.

At the end we compare how the different starting points performed: their starting GA metrics, their post-optimization metrics, and where each of them ended up on the FC vs. frequency-gradient trade-off.

Environment Setup and Imports
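import os
N = 8          # number of (virtual) devices used for pmap sharding below
cpu = True
if cpu:
    # Must be set before jax is imported so XLA exposes N host devices.
    os.environ['XLA_FLAGS'] = f'--xla_force_host_platform_device_count={N}'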

import copy
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from scipy.signal import spectrogram
import jax
import jax.numpy as jnp
import optax

jax.config.update("jax_enable_x64", True)

from tvboptim.types import Parameter, BoundedParameter, DataAxis
from tvboptim.types.spaces import Space
from tvboptim.execution import ParallelExecution
from tvboptim.utils import set_cache_path, cache
from tvboptim.optim.optax import OptaxOptimizer

from tvboptim.experimental.network_dynamics import Network, prepare
from tvboptim.experimental.network_dynamics.dynamics.tvb import SupHopf
from tvboptim.experimental.network_dynamics.coupling import FastLinearCoupling
from tvboptim.experimental.network_dynamics.graph import DenseDelayGraph
from tvboptim.experimental.network_dynamics.solvers import Heun
from tvboptim.experimental.network_dynamics.noise import AdditiveNoise
from tvboptim.data import load_structural_connectivity, load_functional_connectivity

from tvboptim.observations.tvb_monitors.bold import Bold
from tvboptim.observations.observation import compute_fc, fc_corr

import pandas as pd
from pymoo.core.problem import Problem
from pymoo.algorithms.moo.nsga2 import NSGA2
from pymoo.optimize import minimize as pymoo_minimize
from pymoo.indicators.hv import HV

set_cache_path("./hopf_pareto")

Loading Data and Defining Target Frequency Gradient

weights, lengths, region_labels = load_structural_connectivity(name="dk_average")
fc_target = np.asarray(load_functional_connectivity(name="dk_average"))

weights = weights / np.max(weights)
n_nodes = weights.shape[0]

speed = 3.0
delays = lengths / speed

region_labels = np.array(region_labels)
idx_l = np.where(region_labels == "L.LOG")[0]
idx_r = np.where(region_labels == "R.LOG")[0]
dist_from_vc = np.array(np.squeeze(0.5 * (lengths[idx_l, :] + lengths[idx_r, :])))

f_min, f_max = 7.0, 11.0
min_dist, max_dist = dist_from_vc.min(), dist_from_vc.max()
peak_freqs_target = f_max - (f_max - f_min) / (max_dist - min_dist) * (dist_from_vc - min_dist)

# Per-region natural frequency in Hopf radians per ms. With dt in ms and
# frequency in Hz, omega = 2*pi*f / 1000. It stays fixed at this value during
# the GA stage and serves as the initialization in the gradient stage, so every
# region starts from the right natural frequency.
omega_target = 2 * np.pi * peak_freqs_target / 1000.0
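The \(\omega\) initialization rests on the claim that an isolated Hopf node oscillates at \(\omega/2\pi\) cycles per ms. The sketch below checks this for one noise-free node and reads the frequency off its upward zero crossings. It assumes that SupHopf uses the standard normal form \(\dot x = (a - x^2 - y^2)x - \omega y\), \(\dot y = (a - x^2 - y^2)y + \omega x\); that is an assumption about the tvboptim model, not confirmed here. The check is written in plain NumPy with a hand-rolled Heun step rather than the library solver.

```python
import numpy as np

def hopf_rhs(x, y, a, omega):
    # Assumed SupHopf normal form (supercritical Hopf, no coupling, no noise)
    r2 = x * x + y * y
    return (a - r2) * x - omega * y, (a - r2) * y + omega * x

def simulate_isolated_node(f_hz, a=0.01, dt=1.0, t_ms=10_000.0):
    omega = 2 * np.pi * f_hz / 1000.0      # Hz -> rad per ms, as in omega_target
    n = int(t_ms / dt)
    xs = np.empty(n)
    x, y = np.sqrt(max(a, 1e-6)), 0.0      # start on the limit cycle (radius sqrt(a))
    for i in range(n):
        # Heun step: Euler predictor, then trapezoidal corrector
        dx1, dy1 = hopf_rhs(x, y, a, omega)
        xp, yp = x + dt * dx1, y + dt * dy1
        dx2, dy2 = hopf_rhs(xp, yp, a, omega)
        x, y = x + 0.5 * dt * (dx1 + dx2), y + 0.5 * dt * (dy1 + dy2)
        xs[i] = x
    return xs

def zero_crossing_freq(xs, dt=1.0):
    # Upward zero crossings; frequency from the span between first and last one
    idx = np.where((xs[:-1] < 0) & (xs[1:] >= 0))[0]
    t = idx * dt / 1000.0                  # ms -> s
    return (len(idx) - 1) / (t[-1] - t[0])

for f in (7.0, 9.0, 11.0):
    print(f, round(zero_crossing_freq(simulate_isolated_node(f)), 3))
```

In this normal form the phase advances at exactly \(\omega\) whatever the amplitude, so the frequency does not depend on \(a\). The small residual error comes from the discrete Heun step and the 1 ms crossing resolution.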

Building the Hopf Network

row_sums = weights.sum(axis=1)
weights_laplacian = weights - np.diag(row_sums)

graph = DenseDelayGraph(weights_laplacian, delays, region_labels=region_labels)
dynamics = SupHopf(a=0.01, omega=jnp.asarray(omega_target))
coupling = FastLinearCoupling(local_states=["x"], G=0.025)
noise = AdditiveNoise(sigma=1e-2, apply_to=["x", "y"])

network = Network(
    dynamics=dynamics,
    coupling={"instant": coupling},
    graph=graph,
    noise=noise,
)
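Subtracting the row sums places the negative node strengths on the diagonal, which turns the weights into a graph Laplacian. Assume FastLinearCoupling feeds each node \(G\sum_j W_{ij} x_j\) over whatever matrix the graph carries; this is an assumption, and the toy below also ignores delays. Under that assumption the Laplacian turns the input into diffusive coupling \(G\sum_j W_{ij}(x_j - x_i)\), so a perfectly synchronized, delay-free state receives no coupling input. Here is a quick NumPy check on a random symmetric toy connectome:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5
W = rng.random((n, n))
W = 0.5 * (W + W.T)                 # symmetric toy "connectome"
np.fill_diagonal(W, 0.0)
W = W / W.max()                     # same max-normalization as above

L = W - np.diag(W.sum(axis=1))      # same construction as weights_laplacian

x = rng.standard_normal(n)          # arbitrary instantaneous node states
linear_on_laplacian = L @ x
diffusive = np.array([np.sum(W[i] * (x - x[i])) for i in range(n)])
print(np.allclose(linear_on_laplacian, diffusive))   # True

x_sync = np.full(n, 0.7)            # all nodes in the same state
print(np.allclose(L @ x_sync, 0.0))                   # True
```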

Preparing and Running the Simulation

t1 = 60_000   # ms
dt = 1.0
fs = 1000.0 / dt
solver = Heun()

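# Warm-up run: its output becomes the network's delay history
# (update_history), then the model is prepared again from that history.
model, state = prepare(network, solver, t1=t1, dt=dt)
result_init = model(state)
network.update_history(result_init)
model, state = prepare(network, solver, t1=t1, dt=dt)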

bold_monitor = Bold(period=1000.0, downsample_period=dt, voi=0, history=result_init)
# Quick look at the warm-up run; only the first 10 of the 84 nodes are plotted by default.
result_init.plot(default_window=300, dpi=120)

Defining Observations and Loss

The loss adds the two competing objectives, each written as a negative correlation so that lower is better:

\[\mathcal{L} = -\,\rho_{\mathrm{FC}} \;-\; \rho_{\text{freq-grad}}\]

The first term, fc_loss \(= -\rho_{\mathrm{FC}}\), rewards matching the empirical FC matrix. On its own it is minimized by global synchronization (large \(G\)). The second term is \(-\)fgrad_corr, where fgrad_corr is the Pearson \(\rho\) between the soft-argmax peak frequencies and the target gradient. It rewards reproducing the spatial peak-frequency gradient and, on its own, is minimized by decoupling (\(G \to 0\)). Carrying both terms is the regularization: neither can reach its degenerate optimum without paying in the other.

fs_obs = 100.0
subsample = int(fs / fs_obs)

WELCH_NPERSEG = 512
WELCH_NFFT    = 2048

def simulate(state):
    result = model(state)
    bold = bold_monitor(result)
    fc = compute_fc(bold, skip_t=20)
    x_sub = result.data[::subsample, 0, :]
    _, psd = jax.scipy.signal.welch(x_sub.T, fs=fs_obs,
                                    nperseg=WELCH_NPERSEG, nfft=WELCH_NFFT)
    return result, bold, fc, psd

def soft_peak_freqs(psd, transient=3, beta=150.0):
    f_eval = jnp.linspace(0, fs_obs / 2, psd.shape[1])
    psd_t = psd[:, transient:]
    f_t = f_eval[transient:]
    psd_norm = psd_t / (jnp.max(psd_t, axis=1, keepdims=True) + 1e-12)
    w = jax.nn.softmax(beta * psd_norm, axis=1)
    return jnp.sum(w * f_t[None, :], axis=1)

peak_freqs_uniform = 9.0 * jnp.ones(n_nodes)

def loss(state):
    _, _, fc, psd = simulate(state)
    fc_loss = -fc_corr(fc, fc_target)
    fitted_peaks = soft_peak_freqs(psd)
    fgrad_corr = jnp.corrcoef(fitted_peaks, peak_freqs_target)[0, 1]
    return fc_loss - fgrad_corr
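soft_peak_freqs replaces the non-differentiable argmax over the Welch spectrum with a softmax-weighted average of the frequency bins. This lets gradients flow from the peak frequency back to the model parameters. Below is a stand-alone sketch of the same idea on a synthetic spectrum. The Gaussian spectral shape and its 0.5 Hz width are illustrative assumptions; the frequency grid matches fs_obs = 100 Hz and nfft = 2048.

```python
import jax
import jax.numpy as jnp

# Frequency grid of a one-sided spectrum with fs = 100 Hz, nfft = 2048
freqs = jnp.linspace(0.0, 50.0, 1025)

def gaussian_psd(center_hz, width_hz=0.5):
    # Illustrative single-peak spectrum (not a real Welch estimate)
    return jnp.exp(-0.5 * ((freqs - center_hz) / width_hz) ** 2)

def soft_peak(psd, beta=150.0):
    # Normalize to [0, 1] so beta sets the sharpness independently of power scale
    psd_norm = psd / (jnp.max(psd) + 1e-12)
    w = jax.nn.softmax(beta * psd_norm)    # numerically stable, weights sum to 1
    return jnp.sum(w * freqs)

psd = gaussian_psd(9.0)
hard_peak = freqs[jnp.argmax(psd)]         # snaps to the nearest bin
soft = soft_peak(psd)                      # varies smoothly with the spectrum

# The soft peak is differentiable w.r.t. where the spectral peak sits
dpeak_dcenter = jax.grad(lambda c: soft_peak(gaussian_psd(c)))(9.0)
print(float(hard_peak), float(soft), float(dpeak_dcenter))
```

With beta = 150, only bins whose normalized power is within a few percent of the maximum carry appreciable weight. Larger beta approaches the hard argmax. As beta shrinks, the weights flatten toward uniform and the estimate drifts toward the middle of the frequency range.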

The Conflict, Demonstrated

Before building the full workflow, we test the intro’s claim directly. From one moderate starting point we run three short optimizations of 50 Adam steps each: one on the FC term alone, one on the frequency-gradient term alone, one on the combined loss. Then we see where each lands.

The starting point is a hand-picked moderate-coupling state. Its \(G\), \(a\), and \(\sigma\) are close to a mid-front Pareto solution from Stage 0 but are hardcoded here, so the demonstration does not depend on the GA. All three runs free the same parameters as the main workflow: \(G\) and per-region \(a\) and \(\omega\). Because \(\omega\) is free, the gradient-only run can retune per-region frequencies without lowering \(G\), so its pull on \(G\) may come out weaker than the FC-only run's.

LR               = 0.001
N_ABLATION_STEPS = 50

def loss_fc_only(state):
    _, _, fc, _ = simulate(state)
    return -fc_corr(fc, fc_target)

def loss_grad_only(state):
    _, _, _, psd = simulate(state)
    fitted_peaks = soft_peak_freqs(psd)
    return -jnp.corrcoef(fitted_peaks, peak_freqs_target)[0, 1]

# Combined-loss optimizer; reused unchanged by the Stage 1 parallel sweep.
optimizer      = OptaxOptimizer(loss, optax.adam(LR))
optimizer_fc   = OptaxOptimizer(loss_fc_only, optax.adam(LR))
optimizer_grad = OptaxOptimizer(loss_grad_only, optax.adam(LR))

# A moderate-coupling starting point. These scalars are close to the
# median-`combined` Pareto seed that Stage 0 finds, but are hardcoded here so
# the demonstration runs before the genetic search and independently of it.
motivation_seed = {"G": 0.15, "a": 0.005, "sigma": 0.031}

Run experiment:

def _ablation_metrics(eval_state):
    _, _, fc_e, psd_e = simulate(eval_state)
    peaks_e = soft_peak_freqs(psd_e)
    return (float(fc_corr(fc_e, fc_target)),
            float(jnp.corrcoef(peaks_e, peak_freqs_target)[0, 1]))

def run_one_ablation(opt):
    """Optimize the motivation seed under `opt` for N_ABLATION_STEPS, omega free."""
    abl_state = copy.deepcopy(state)
    abl_state.coupling.instant.G = Parameter(jnp.asarray(motivation_seed["G"]))
    abl_state.dynamics.a         = Parameter(motivation_seed["a"] * jnp.ones(n_nodes))
    abl_state.dynamics.omega     = BoundedParameter(jnp.asarray(omega_target),
                                                    0.0, jnp.inf)
    abl_state.noise.sigma        = jnp.asarray(motivation_seed["sigma"])
    final_state, _ = opt.run(abl_state, max_steps=N_ABLATION_STEPS,
                             chunk_size=N_ABLATION_STEPS)
    fc_c, grad_c = _ablation_metrics(final_state)
    return {"fc_corr": fc_c, "freq_corr": grad_c,
            "G": float(jnp.asarray(final_state.coupling.instant.G))}

@cache("ablation_motivation", redo=False)
def run_ablation():
    # Start point: the motivation seed, omega at the target gradient.
    start_state = copy.deepcopy(state)
    start_state.coupling.instant.G = jnp.asarray(motivation_seed["G"])
    start_state.dynamics.a         = jnp.asarray(motivation_seed["a"])
    start_state.noise.sigma        = jnp.asarray(motivation_seed["sigma"])
    start_fc, start_grad = _ablation_metrics(start_state)
    return {
        "start":     {"fc_corr": start_fc, "freq_corr": start_grad},
        "fc_only":   run_one_ablation(optimizer_fc),
        "grad_only": run_one_ablation(optimizer_grad),
        "combined":  run_one_ablation(optimizer),
    }

ablation = run_ablation()

print(f"Motivation seed: G={motivation_seed['G']}, "
      f"a={motivation_seed['a']}, sigma={motivation_seed['sigma']}")
for k in ["start", "fc_only", "grad_only", "combined"]:
    v = ablation[k]
    extra = "" if k == "start" else f", G={v['G']:.4f}"
    print(f"  {k:10s}: FC corr={v['fc_corr']:.3f}, "
          f"freq corr={v['freq_corr']:.3f}{extra}")
Loading ablation_motivation from cache, last modified 2026-05-22 15:21:58.648658
Motivation seed: G=0.15, a=0.005, sigma=0.031
  start     : FC corr=0.218, freq corr=0.283
  fc_only   : FC corr=0.756, freq corr=-0.108, G=0.1937
  grad_only : FC corr=0.137, freq corr=0.807, G=0.1265
  combined  : FC corr=0.455, freq corr=0.729, G=0.1792
start = ablation["start"]
sx, sy = 1.0 - start["freq_corr"], 1.0 - start["fc_corr"]

runs = [("FC only",       "fc_only",   "tab:red"),
        ("Gradient only", "grad_only", "tab:green"),
        ("Combined",      "combined",  "tab:blue")]

fig, ax = plt.subplots(figsize=(6, 5.5))
for label, key, color in runs:
    r = ablation[key]
    ex, ey = 1.0 - r["freq_corr"], 1.0 - r["fc_corr"]
    ax.annotate("", xy=(ex, ey), xytext=(sx, sy),
                arrowprops=dict(arrowstyle="->", color=color, lw=2.0,
                                shrinkA=6, shrinkB=6))
    ax.scatter([ex], [ey], color=color, s=170, marker="*",
               edgecolors="black", linewidths=0.6, zorder=3, label=label)
ax.scatter([sx], [sy], color="black", s=110, marker="o",
           zorder=4, label="Starting seed")
ax.set_xlabel(r"$1 - \rho_{\text{freq-grad}}$")
ax.set_ylabel(r"$1 - \rho_{FC}$")
ax.set_title(f"Where each loss pulls the fit ({N_ABLATION_STEPS} Adam steps)")
ax.grid(alpha=0.3)
ax.legend(loc="best", fontsize=9)
plt.tight_layout()
plt.show()
Figure 1: Each loss pulls the fit a different way. Three 50-step Adam optimizations from one moderate-coupling seed (black circle), each minimizing a different loss; stars mark the endpoints. Lower-left is good on both objectives. All three free \(G\) and per-region \(a\) and \(\omega\).

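The printout and Figure 1 bear this out. The FC-only run raises \(G\) from 0.15 to about 0.19 and reaches an FC correlation of 0.756, but the frequency-gradient correlation turns negative (\(-0.108\)). The gradient-only run lowers \(G\) to about 0.13 and reaches a gradient correlation of 0.807, while FC correlation falls to 0.137. The combined loss lands in between (FC 0.455, gradient 0.729). The rest of the workflow handles this trade-off systematically: Stage 0 maps a Pareto front over the global scalars, and Stage 1 optimizes from every point on it.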

Stage 0: Genetic Pre-Search (NSGA-II)

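A small NSGA-II search over \((G, a, \log_{10}\sigma)\) with two objectives. The first is the mean absolute error (in Hz) between each region's soft-argmax peak frequency and a uniform 9 Hz, labelled PSD MAE below. The second is \(1 - \rho_{\mathrm{FC}}\). \(\omega\) is not searched; every region keeps its target natural frequency, set above as omega_target.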
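NSGA-II ranks candidates by non-dominance. A solution is dominated if another one is no worse on both objectives and strictly better on at least one. The non-dominated set is what this notebook calls the Pareto front. Below is a minimal O(n²) NumPy filter for minimized objectives. It is a sketch, not pymoo's fast non-dominated sorting, and the objective values are invented for illustration.

```python
import numpy as np

def non_dominated_mask(F):
    """Boolean mask of the non-dominated rows of F (all objectives minimized)."""
    F = np.asarray(F, dtype=float)
    mask = np.ones(len(F), dtype=bool)
    for i in range(len(F)):
        # Row j dominates row i if it is <= on every objective and < on at least one
        dominates_i = np.all(F <= F[i], axis=1) & np.any(F < F[i], axis=1)
        mask[i] = not dominates_i.any()
    return mask

# Columns: (PSD MAE in Hz, 1 - rho_FC); toy numbers
F = np.array([
    [0.1, 0.95],   # on-frequency but poor FC
    [2.0, 0.75],   # trade-off point
    [3.5, 0.74],   # slightly better FC, worse frequency
    [2.5, 0.80],   # dominated by [2.0, 0.75]
    [4.0, 0.90],   # dominated by [2.0, 0.75] and [3.5, 0.74]
])
print(non_dominated_mask(F))   # [ True  True  True False False]
```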

def ga_evaluate(state):
    _, _, fc, psd = simulate(state)
    fitted_peaks = soft_peak_freqs(psd)
    # Stage 0's frequency objective is PSD MAE to a *uniform* 9 Hz, not to
    # the spatial gradient: it only asks the global scalars to keep every
    # region oscillating in the alpha band. Reproducing the per-region
    # gradient is left to the Stage 1 gradient loss (`peak_freqs_target`).
    psd_mae = jnp.mean(jnp.abs(fitted_peaks - peak_freqs_uniform))
    fc_correlation = fc_corr(fc, fc_target)
    return {"psd_mae": psd_mae, "fc_correlation": fc_correlation}

GA_PARAM_RANGES = {
    "G":          [0.001,  0.15],
    "a":          [-0.1,   0.01],
    "log10_sigma": [-2.0, -0.5],
}
GA_POP_SIZE = N * 1
GA_N_GEN = 40

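# PSD MAE is in Hz while 1 - rho_FC lies in [0, 2]; M1_SCALE (9 Hz) puts the
# frequency objective on a comparable scale for the combined score and the
# hypervolume reference point.
M1_SCALE = 9.0
GA_REF_POINT = np.array([1.5 * M1_SCALE, 1.5])


class HopfGAProblem(Problem):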
    """NSGA-II problem: search (G, a, log10_sigma) on (PSD MAE, 1 - FC corr).

    Omega is *not* searched. Every region keeps its target natural frequency
    baked into `base_state` (via SupHopf), so the GA only varies the three
    global scalars. Each population is evaluated in parallel across
    `n_devices` pmap shards. Per-evaluation metrics are appended to
    `eval_log` so the caller can inspect them after the run.
    """

    def __init__(self, base_state, n_devices, eval_log):
        # Decision-variable bounds; column order is [G, a, log10(sigma)].
        xl = np.array([GA_PARAM_RANGES["G"][0],
                       GA_PARAM_RANGES["a"][0],
                       GA_PARAM_RANGES["log10_sigma"][0]])
        xu = np.array([GA_PARAM_RANGES["G"][1],
                       GA_PARAM_RANGES["a"][1],
                       GA_PARAM_RANGES["log10_sigma"][1]])
        super().__init__(n_var=3, n_obj=2, xl=xl, xu=xu)
        self.base_state = base_state
        self.n_devices  = n_devices
        self.eval_log   = eval_log

    def _evaluate(self, X, out, *args, **kwargs):
        # Decode the third decision column out of log10-space.
        sigma_phys = 10.0 ** X[:, 2]

        # Build a batched state: each population row rides a DataAxis on
        # G / a / sigma. omega is left untouched in base_state, where it
        # already holds the per-region target natural frequency.
        batch_state = copy.deepcopy(self.base_state)
        batch_state.coupling.instant.G = DataAxis(jnp.asarray(X[:, 0]))
        batch_state.dynamics.a         = DataAxis(jnp.asarray(X[:, 1]))
        batch_state.noise.sigma        = DataAxis(jnp.asarray(sigma_phys))

        # zip-mode walks the three DataAxes together (one config per row),
        # which is what we want for a flat population sweep.
        space     = Space(batch_state, mode="zip")
        execution = ParallelExecution(ga_evaluate, space, n_pmap=self.n_devices)
        results   = execution.run()

        # Pack objectives + log every evaluation for the post-mortem plots.
        F = np.empty((len(results), 2))
        for i, r in enumerate(results):
            psd_mae = float(r["psd_mae"])
            fc_c    = float(r["fc_correlation"])
            F[i] = (psd_mae, 1.0 - fc_c)
            self.eval_log.append({
                "G":              float(X[i, 0]),
                "a":              float(X[i, 1]),
                "sigma":          float(sigma_phys[i]),
                "psd_mae":        psd_mae,
                "fc_correlation": fc_c,
            })

        # NaNs/infs would poison non-dominated sorting; replace them with a large penalty.
        out["F"] = np.nan_to_num(F, nan=1e6, posinf=1e6, neginf=1e6)


def _make_hv_callback(hv_log):
    """Build an NSGA-II callback that logs hypervolume per generation."""
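    def cb(algorithm):
        F = algorithm.pop.get("F")
        try:
            # Dominated points add no volume, so the HV of the whole
            # population equals the HV of its non-dominated subset.
            hv = HV(ref_point=GA_REF_POINT)(F)
        except Exception:
            hv = np.nan
        hv_log.append(hv)
        # algorithm.opt holds the current non-dominated set; len(F) is the population size.
        print(f"[Gen {algorithm.n_gen}] HV={hv:.4f} | ND={len(algorithm.opt)} | pop={len(F)}")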
    return cb


@cache("ga_hopf_pareto", redo=False)
def run_ga(pop_size=GA_POP_SIZE, n_gen=GA_N_GEN):
    eval_log = []
    hv_log   = []

    res = pymoo_minimize(
        problem=HopfGAProblem(base_state=state, n_devices=N, eval_log=eval_log),
        algorithm=NSGA2(pop_size=pop_size),
        termination=("n_gen", n_gen),
        callback=_make_hv_callback(hv_log),
        seed=42,
        verbose=True,
    )
    return {
        "pareto_X":   np.asarray(res.X),
        "pareto_F":   np.asarray(res.F),
        "all_evals":  eval_log,
        "hv_per_gen": np.asarray(hv_log),
    }

ga_result = run_ga()

pareto_X = ga_result["pareto_X"].copy()
pareto_X[:, 2] = 10.0 ** pareto_X[:, 2]   # decode log10(sigma) → physical sigma
pareto_F = ga_result["pareto_F"]
all_evals_df = pd.DataFrame(ga_result["all_evals"])
pareto_df = pd.DataFrame(
    np.hstack([pareto_X, pareto_F]),
    columns=["G", "a", "sigma", "psd_mae", "one_minus_fc"],
)
pareto_df["fc_correlation"] = 1.0 - pareto_df["one_minus_fc"]
pareto_df["combined"] = 0.5 * (pareto_df["psd_mae"] / M1_SCALE) + 0.5 * pareto_df["one_minus_fc"]
n_pareto = len(pareto_df)
print(f"Pareto front size: {n_pareto}")

Pareto Front and Convergence

fig, (ax_pf, ax_hv) = plt.subplots(1, 2, figsize=(11, 4))

ax_pf.scatter(all_evals_df["psd_mae"], 1.0 - all_evals_df["fc_correlation"],
              c="lightgray", s=25, alpha=0.6, label="All evaluations")
sc = ax_pf.scatter(pareto_df["psd_mae"], pareto_df["one_minus_fc"],
                   c=pareto_df["combined"], cmap="viridis_r",
                   s=80, edgecolors="black", linewidths=0.8, label="Pareto front")
ax_pf.set_xlabel("PSD MAE to 9 Hz")
ax_pf.set_ylabel(r"$1 - \rho_{FC}$")
ax_pf.set_title("Pareto Front (Stage 0)")
ax_pf.grid(alpha=0.3)
ax_pf.legend(loc="best", fontsize=8)
plt.colorbar(sc, ax=ax_pf, label="Combined (0.5/0.5)")

ax_hv.plot(np.arange(1, len(ga_result["hv_per_gen"]) + 1),
           ga_result["hv_per_gen"], marker="o")
ax_hv.set_xlabel("Generation")
ax_hv.set_ylabel("Hypervolume")
ax_hv.set_title("HV convergence")
ax_hv.grid(alpha=0.3)

plt.tight_layout()
plt.show()
Figure 2: Stage 0 NSGA-II output. Pareto front in objective space (PSD MAE vs \(1 - \rho_{FC}\)); grey dots are all evaluated solutions, colored dots are the non-dominated front. Right panel: hypervolume convergence.
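The right panel tracks hypervolume. This is the area of objective space that the population dominates, bounded by the reference point GA_REF_POINT = (13.5, 1.5). Larger is better, and a plateau suggests the front has stopped improving. For two minimized objectives it reduces to a sweep over the points sorted by the first objective. The sketch below is a plain NumPy illustration, not pymoo's HV implementation, and the example points are made up.

```python
import numpy as np

def hypervolume_2d(F, ref):
    """Area dominated by the rows of F and bounded by ref (both objectives minimized)."""
    F = np.asarray(F, dtype=float)
    ref = np.asarray(ref, dtype=float)
    F = F[np.all(F < ref, axis=1)]             # points outside the reference box add nothing
    order = np.lexsort((F[:, 1], F[:, 0]))     # sort by objective 1, ties by objective 2
    hv, prev_y = 0.0, ref[1]
    for x, y in F[order]:
        if y < prev_y:                         # dominated points never lower prev_y
            # The strip between y and prev_y is dominated from x up to ref[0]
            hv += (ref[0] - x) * (prev_y - y)
            prev_y = y
    return hv

ref_point = np.array([1.5 * 9.0, 1.5])          # same numbers as GA_REF_POINT
front = np.array([[0.1, 0.95], [2.0, 0.75], [3.5, 0.74]])
with_dominated = np.vstack([front, [[2.5, 0.80], [4.0, 0.90]]])

print(hypervolume_2d(front, ref_point))           # ~9.77
print(hypervolume_2d(with_dominated, ref_point))  # same value
```

Dominated points never lower the running minimum, so they contribute nothing. That is also why the callback can pass the whole population to HV rather than only its non-dominated members.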

Stage 1: Parallel Gradient Optimization From Every Pareto Seed

Instead of selecting one Pareto solution and pretuning per-region \(\omega\) and \(a\) on a short simulation, we treat every Pareto front entry as an independent starting point for the full multimodal optimization. All of these runs execute in parallel via ParallelExecution.

Inside optimize_from_seed we mark the optimization targets as trainable by direct attribute assignment. \(G\) and per-region \(a\) become Parameter, and per-region \(\omega\) becomes a BoundedParameter constrained to be non-negative. This is the same assignment pattern Stage 0 used for DataAxis. The optimizer is built once at module scope; OptaxOptimizer.run is pure, re-partitioning the state and re-initializing the optimizer state on every call. Per seed we return only the optimized parameter values plus the scalar metrics. Reconstructing the best-seed state for diagnostics then takes a few assignments downstream.
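optimize_from_seed runs one fused optimization loop per seed and is mapped over a batch of seeds. That structure can be sketched in plain JAX. This is an analogue of what OptaxOptimizer.run and ParallelExecution do here, not their implementation. The toy quadratic loss is a hypothetical stand-in for the simulation loss, and plain gradient descent stands in for Adam to keep the example dependency-free.

```python
import jax
import jax.numpy as jnp

TARGET = jnp.array([1.0, -2.0, 0.5])

def toy_loss(params):
    # Hypothetical stand-in for the simulation-based combined loss
    return jnp.sum((params - TARGET) ** 2)

def toy_optimize_from_seed(seed, n_steps=200, lr=0.05):
    """One full optimization fused into a single lax.scan (plain gradient descent)."""
    def step(params, _):
        value, grad = jax.value_and_grad(toy_loss)(params)
        return params - lr * grad, value
    final, losses = jax.lax.scan(step, seed, xs=None, length=n_steps)
    return final, losses[-1]

# One row per starting point, like the Pareto seeds
seeds = jnp.array([[0.0, 0.0, 0.0],
                   [3.0, 3.0, 3.0],
                   [-1.0, 5.0, 2.0],
                   [10.0, -10.0, 0.0]])

# Sequential map over seeds (memory-friendly) ...
finals, last_losses = jax.lax.map(toy_optimize_from_seed, seeds)
# ... or vectorized over seeds; both run the same per-seed program
finals_v, _ = jax.vmap(toy_optimize_from_seed)(seeds)
print(finals)
```

Every seed runs the same fixed-length lax.scan, so the per-seed function has static shapes. It can therefore be mapped with lax.map (sequential), vmap (vectorized), or pmap across devices, which corresponds to the n_pmap shards used throughout this notebook.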

N_OPT_STEPS = 200   # gradient steps per seed; one fused jax.lax.scan

# `optimizer` (combined loss, adam at LR) was built in the motivation section.

def optimize_from_seed(seed_state):
    """Run one full gradient optimization seeded by a single GA Pareto solution.

    `seed_state` arrives with DataAxis already sliced: G, a, sigma are
    scalars (one per Pareto row), and omega is the (n_nodes,) target
    frequency carried unchanged from `state`. We wrap the three
    optimization targets in Parameter, hand the state to OptaxOptimizer.run
    (chunk_size == max_steps fuses the entire run into one lax.scan), then
    re-simulate to score it.
    """
    G_val     = jnp.asarray(seed_state.coupling.instant.G)
    a_val     = jnp.asarray(seed_state.dynamics.a)
    omega_per = jnp.asarray(seed_state.dynamics.omega)

    seed_state.coupling.instant.G = Parameter(G_val)
    seed_state.dynamics.a         = Parameter(a_val * jnp.ones(n_nodes))
    seed_state.dynamics.omega     = BoundedParameter(omega_per, 0.0, jnp.inf)

    final_state, _ = optimizer.run(
        seed_state, max_steps=N_OPT_STEPS, chunk_size=N_OPT_STEPS,
    )

    _, _, fc_f, psd_f = simulate(final_state)
    final_peaks     = soft_peak_freqs(psd_f)
    final_fc_corr   = fc_corr(fc_f, fc_target)
    final_freq_corr = jnp.corrcoef(final_peaks, peak_freqs_target)[0, 1]
    # jnp.asarray on a BoundedParameter triggers __jax_array__, so
    # omega comes out already in its constrained physical-radians form.
    return {
        "final_loss":      -final_fc_corr - final_freq_corr,
        "final_fc_corr":   final_fc_corr,
        "final_freq_corr": final_freq_corr,
        "final_psd_mae":   jnp.mean(jnp.abs(final_peaks - peak_freqs_uniform)),
        "G_final":         jnp.asarray(final_state.coupling.instant.G),
        "a_final":         jnp.asarray(final_state.dynamics.a),
        "omega_final":     jnp.asarray(final_state.dynamics.omega),
    }

Running the parallel sweep

We feed every Pareto seed as a row of a zip-mode Space over DataAxis-wrapped batches, exactly the same pattern Stage 0 used to evaluate populations.
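In zip mode, the i-th entry of every DataAxis is combined into one configuration, so each Pareto row's \((G, a, \sigma)\) triple stays intact. A Cartesian product would instead mix \(G\) from one seed with \(a\) and \(\sigma\) from another. Plain itertools illustrates the difference. The toy values are hypothetical; only the zip semantics, as described in the Stage 0 comments, come from this notebook.

```python
from itertools import product

# Three hypothetical Pareto rows, split into per-parameter axes
G_vals = [0.14, 0.15, 0.16]
a_vals = [0.005, 0.008, 0.010]
sigma_vals = [0.01, 0.02, 0.03]

# zip: row i of every axis forms one configuration -> n configurations
zip_configs = list(zip(G_vals, a_vals, sigma_vals))
# Cartesian product: every combination -> n**3 configurations
grid_configs = list(product(G_vals, a_vals, sigma_vals))

print(len(zip_configs), len(grid_configs))   # 3 27
print(zip_configs[1])                         # (0.15, 0.008, 0.02)
```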

@cache("parallel_pareto_opt", redo=False)
def run_parallel_pareto_opt():
    seed_state = copy.deepcopy(state)
    seed_state.coupling.instant.G = DataAxis(jnp.asarray(pareto_df["G"].to_numpy()))
    seed_state.dynamics.a         = DataAxis(jnp.asarray(pareto_df["a"].to_numpy()))
    seed_state.noise.sigma        = DataAxis(jnp.asarray(pareto_df["sigma"].to_numpy()))
    # omega stays as the per-region target frequency from `state` — not batched.

    space = Space(seed_state, mode="zip")
    execution = ParallelExecution(optimize_from_seed, space, n_pmap=N)
    results = execution.run()

    rows = [{
        "final_loss":      float(results[i]["final_loss"]),
        "final_fc_corr":   float(results[i]["final_fc_corr"]),
        "final_freq_corr": float(results[i]["final_freq_corr"]),
        "final_psd_mae":   float(results[i]["final_psd_mae"]),
        "G_final":         float(np.asarray(results[i]["G_final"])),
        "a_final":         np.asarray(results[i]["a_final"]),
        "omega_final":     np.asarray(results[i]["omega_final"]),
    } for i in range(len(results))]
    return pd.DataFrame(rows)

opt_df = run_parallel_pareto_opt()
# Merge GA seed table with optimization outcomes for a single overview frame.
summary_df = pd.concat(
    [pareto_df.reset_index(drop=True),
     opt_df.reset_index(drop=True)],
    axis=1,
)
summary_df["seed_idx"] = np.arange(len(summary_df))

# Parameter-drift columns: the gradient stage turns `a` and `omega` into
# per-region vectors, so summarize the spread `a` picked up and how far the
# fitted per-region omega moved from its target-gradient initialization (Hz).
omega_target_arr = np.asarray(omega_target)
summary_df["a_fit_mean"] = summary_df["a_final"].apply(lambda v: float(np.asarray(v).mean()))
summary_df["a_fit_std"]  = summary_df["a_final"].apply(lambda v: float(np.asarray(v).std()))
summary_df["domega_mae_hz"] = summary_df["omega_final"].apply(
    lambda v: float(np.mean(np.abs((np.asarray(v) - omega_target_arr) * 1000.0 / (2 * np.pi))))
)

Overview: How Did the Different Starting Points Perform?

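A single table per seed shows the GA seed scalars, the gradient-optimized parameters, and the before/after fit metrics. The parameter columns include the per-region spread the gradient stage gave \(a\) and how far \(\omega\) drifted from its target-gradient initialization, in Hz. Note that final_psd_mae is still measured against the uniform 9 Hz target of Stage 0, which is not part of the gradient-stage loss. Pearson correlation is invariant to shifting and rescaling the fitted peak frequencies, so the gradient term does not pin their absolute values, and PSD MAE is free to grow during Stage 1.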

 seed_idx      G      a  sigma  G_final  a_fit_mean  a_fit_std  domega_mae_hz  psd_mae  final_psd_mae  fc_correlation  final_fc_corr  final_freq_corr
        0 0.1425 0.0095 0.0105   0.1332      0.0031     0.0152         1.2488   0.0786         3.0963          0.0421         0.2674           0.7473
        1 0.1499 0.0100 0.0608   0.2426      0.0232     0.0400         4.9028   3.4770         3.4105          0.2588         0.6483           0.7085
        2 0.1435 0.0053 0.0211   0.1968      0.0018     0.0321         3.8025   2.0441         3.1103          0.1584         0.6589           0.8996
        3 0.1473 0.0097 0.0173   0.1886     -0.0041     0.0293         3.6779   1.0136         3.5584          0.1123         0.6262           0.9187
        4 0.1499 0.0100 0.0501   0.2382      0.0156     0.0417         4.7697   3.0986         3.5601          0.2548         0.6450           0.7910
        5 0.1500 0.0080 0.0309   0.2361      0.0071     0.0273         4.0419   2.3750         3.0746          0.2119         0.6626           0.8004
        6 0.1499 0.0052 0.0309   0.2230      0.0049     0.0316         4.1407   2.5672         3.1037          0.2172         0.6479           0.8316
        7 0.1497 0.0095 0.0148   0.1708      0.0013     0.0204         2.7886   0.3885         3.5292          0.0912         0.4981           0.8334
final_one_minus_fc = 1.0 - summary_df["final_fc_corr"].to_numpy()
init_one_minus_fc  = summary_df["one_minus_fc"].to_numpy()
init_psd_mae       = summary_df["psd_mae"].to_numpy()
final_psd_mae      = summary_df["final_psd_mae"].to_numpy()
init_G             = summary_df["G"].to_numpy()
final_G            = summary_df["G_final"].to_numpy()

# Shared color scale so the GA-seed G (circles) and optimized G (stars)
# are directly comparable. The max is capped at the GA-seed G range so the
# gradient across the initial Pareto seeds (circles) stays visible;
# optimized G values above this cap (stars) clip to the top color.
norm = plt.Normalize(vmin=init_G.min(),
                     vmax=0.165)
cmap = "plasma"

fig, ax = plt.subplots(figsize=(7.5, 5))
ax.scatter(all_evals_df["psd_mae"], 1.0 - all_evals_df["fc_correlation"],
           c="lightgray", s=18, alpha=0.5, label="GA evaluations")
for i in range(len(summary_df)):
    ax.annotate("",
                xy=(final_psd_mae[i], final_one_minus_fc[i]),
                xytext=(init_psd_mae[i], init_one_minus_fc[i]),
                arrowprops=dict(arrowstyle="->", color="grey", alpha=0.55,
                                shrinkA=2, shrinkB=4))
ax.scatter(init_psd_mae, init_one_minus_fc, c=init_G, cmap=cmap, norm=norm,
           s=70, marker="o", edgecolors="black", linewidths=1.0,
           label="GA Pareto starts")
sc = ax.scatter(final_psd_mae, final_one_minus_fc, c=final_G, cmap=cmap, norm=norm,
                s=170, marker="*", edgecolors="black", linewidths=0.6,
                label="After parallel opt")
ax.set_xlabel("PSD MAE to 9 Hz")
ax.set_ylabel(r"$1 - \rho_{FC}$")
ax.set_xlim(0.0, 10)
ax.grid(alpha=0.3)
ax.legend(loc="best", fontsize=8)
plt.colorbar(sc, ax=ax, label=r"Coupling strength $G$", extend="max")
plt.tight_layout()
plt.show()
Figure 3: Where each seed landed after gradient optimization. Circles: GA Pareto front (starting objectives: PSD MAE to 9 Hz vs \(1 - \rho_{FC}\)). Stars: post-optimization objectives. Arrows connect each seed’s start to its end. Color encodes the coupling strength \(G\) on a shared scale — circles show the GA seed value, stars the gradient-optimized value. The color scale is capped so the gradient across the GA seeds stays visible; \(G\) values above the cap clip to the top color (arrow on the colorbar).

Diagnostics for the Best Seed

We pick the seed with the lowest combined post-optimization loss and re-simulate its final state for a full diagnostic panel.

final_losses = summary_df["final_loss"].to_numpy()
best_i = int(np.nanargmin(np.where(np.isfinite(final_losses), final_losses, np.nan)))
best_row = summary_df.iloc[best_i]
print(f"Best seed index: {best_i}")
print(f"  GA start  : G={best_row['G']:.4f}, a={best_row['a']:.4f}, "
      f"sigma={best_row['sigma']:.4f}  (omega initialized from per-region target)")
print(f"  GA metrics: PSD MAE={best_row['psd_mae']:.4f}, FC corr={best_row['fc_correlation']:.4f}")
print(f"  Final     : PSD MAE={best_row['final_psd_mae']:.4f}, FC corr={best_row['final_fc_corr']:.4f}, "
      f"freq corr={best_row['final_freq_corr']:.4f}")

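best_state = copy.deepcopy(state)
best_state.coupling.instant.G = jnp.asarray(best_row["G_final"])
best_state.dynamics.a         = jnp.asarray(best_row["a_final"])
best_state.dynamics.omega     = jnp.asarray(best_row["omega_final"])
# sigma was not optimized, but each seed ran with its own GA value, not the default 1e-2.
best_state.noise.sigma        = jnp.asarray(best_row["sigma"])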
result_best, bold_best, fc_best, psd_best = simulate(best_state)
fc_best   = np.asarray(fc_best)
psd_best  = np.asarray(psd_best)
ts_best   = np.asarray(result_best.data)
bold_best = np.asarray(bold_best.data)

f_eval = np.linspace(0, fs_obs / 2, psd_best.shape[1])
transient = 3
peak_freqs_best = f_eval[transient:][np.argmax(psd_best[:, transient:], axis=1)]
mask_best = peak_freqs_best > 0
import scipy.stats

fig = plt.figure(figsize=(11, 7.5))
gs = gridspec.GridSpec(2, 5, height_ratios=[1, 1])

ax_ts    = fig.add_subplot(gs[0, 0:3])
ax_fgrad = fig.add_subplot(gs[0, 3:])
ax_bold  = fig.add_subplot(gs[1, 0:3])
ax_fc    = fig.add_subplot(gs[1, 3:])

idx_node = 5   # region highlighted across panels a-c

def _panel_letter(ax, letter):
    ax.set_title(letter, loc="left", fontweight="bold", fontsize=12)

# --- a: spectrogram of the highlighted region ---
ts_data  = ts_best[-30000:, 0, idx_node]
nperseg  = 128 * subsample
noverlap = nperseg - 8 * subsample   # ~94% overlap → many, smoothly-spaced time slices
nfft     = 2 * nperseg               # zero-pad → finer, smoother frequency axis
f_sp, t_sp, Sxx = spectrogram(ts_data, fs=fs, nperseg=nperseg,
                              noverlap=noverlap, nfft=nfft)
Sxx = Sxx / np.max(Sxx)                              
im_spec = ax_ts.pcolormesh(t_sp, f_sp, Sxx, shading='gouraud', cmap='viridis')
ax_ts.set_ylim(0, 20)
ax_ts.set_xlabel('Time [s]')
ax_ts.set_ylabel('Frequency [Hz]')
ax_ts.set_title('Spectrogram (highlighted region)')
plt.colorbar(im_spec, ax=ax_ts, label='Power [a.u.]', shrink=0.85)
_panel_letter(ax_ts, "a")

# --- b: fitted vs target peak frequency, square axes with identity line ---
freq_lo, freq_hi = 5.5, 12.0
pr_f = np.corrcoef(peak_freqs_target[mask_best], peak_freqs_best[mask_best])[0, 1]
sr_f = scipy.stats.spearmanr(peak_freqs_target[mask_best], peak_freqs_best[mask_best])[0]
ax_fgrad.plot([freq_lo, freq_hi], [freq_lo, freq_hi], 'k--', lw=1, zorder=1)
ax_fgrad.scatter(peak_freqs_target[mask_best], peak_freqs_best[mask_best],
                 s=40, color="grey", edgecolors="white", linewidths=0.5,
                 zorder=2, label="region")
# index the unmasked arrays so idx_node refers to the same region as panels a and c
ax_fgrad.scatter(peak_freqs_target[idx_node], peak_freqs_best[idx_node],
                 color="royalblue", marker="s", s=120, edgecolors="k",
                 zorder=3, label="highlighted region")
ax_fgrad.set_xlim(freq_lo, freq_hi)
ax_fgrad.set_ylim(freq_lo, freq_hi)
ax_fgrad.set_aspect("equal")
ax_fgrad.set_xlabel("Target peak frequency [Hz]")
ax_fgrad.set_ylabel("Fitted peak frequency [Hz]")
ax_fgrad.set_title("Peak frequency")
ax_fgrad.legend(loc="upper left", fontsize=8, framealpha=0.9)
ax_fgrad.text(0.96, 0.04,
              rf"$\rho_P = {pr_f:.2f}$" + "\n" + rf"$\rho_S = {sr_f:.2f}$",
              transform=ax_fgrad.transAxes, ha="right", va="bottom", fontsize=9,
              bbox=dict(boxstyle="round,pad=0.3", edgecolor="0.6", facecolor="white"))
ax_fgrad.grid(alpha=0.3)
_panel_letter(ax_fgrad, "b")

# --- c: BOLD time courses ---
ax_bold.plot(bold_best[-50:, 0, :10], color="grey", alpha=0.5)
ax_bold.plot(bold_best[-50:, 0, idx_node], color="royalblue", linewidth=1.5,
             label='highlighted region')
ax_bold.set_xlabel("Time [TR]")
ax_bold.set_ylabel("BOLD [a.u.]")
ax_bold.set_title('BOLD time courses')
ax_bold.legend(loc="upper right", fontsize=8)
_panel_letter(ax_bold, "c")

# --- d: functional connectivity, target vs estimated ---
fc_dual = np.tril(fc_target) + np.triu(fc_best, k=1)
im_fc = ax_fc.imshow(fc_dual, cmap='cividis', vmin=0, vmax=1.0)
ax_fc.set_title(rf"FC ($\rho_P = {fc_corr(fc_best, fc_target):.3f}$)")
ax_fc.set_xlabel("Target FC (lower) / Estimated FC (upper)")
plt.colorbar(im_fc, ax=ax_fc, shrink=0.85)
_panel_letter(ax_fc, "d")

fig.tight_layout()
plt.show()
Figure 4: Diagnostics for the best seed. a: Spectrogram of the highlighted region. The instantaneous peak frequency drifts and jumps over time, a coupling-driven network effect rather than a fixed single-region rhythm. b: Fitted vs. target peak frequency per region; the dashed line is the identity, and \(\rho_P\) / \(\rho_S\) are the Pearson / Spearman correlations. Each fitted value is a single peak read from a PSD computed over the whole recording, so this time-averaging hides the wandering seen in panel a, and the tight clustering along the identity line overstates how stable each region's rhythm actually is. c: BOLD time courses over the last 50 TRs for the first ten regions (grey), highlighted region in blue. d: Target FC (lower triangle) vs. estimated FC (upper triangle).
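The point made for panels a and b can be reproduced on a toy signal whose rhythm switches frequency halfway through. A spectrogram tracks the switch slice by slice, while a single PSD over the whole recording yields one peak. This is a minimal sketch; the sampling rate, window sizes and the 7 Hz / 10 Hz switch are hypothetical choices, not the notebook's settings. (For reference, the notebook's `noverlap = nperseg - 8 * subsample` with `nperseg = 128 * subsample` gives 120/128 = 93.75% overlap.)

```python
import numpy as np
from scipy.signal import spectrogram, welch

fs = 200.0                                   # hypothetical sampling rate [Hz]
t = np.arange(int(20 * fs)) / fs             # 20 s of samples
# Toy "wandering" rhythm: 7 Hz for the first 10 s, 10 Hz afterwards
f_inst = np.where(t < 10.0, 7.0, 10.0)
phase = 2 * np.pi * np.cumsum(f_inst) / fs   # integrate frequency -> continuous phase
x = np.sin(phase)

nperseg = 256
f_sp, t_sp, Sxx = spectrogram(x, fs=fs, nperseg=nperseg,
                              noverlap=nperseg - 16, nfft=2 * nperseg)
peak_per_slice = f_sp[np.argmax(Sxx, axis=0)]    # instantaneous peak per time slice

f_w, Pxx = welch(x, fs=fs, nperseg=nperseg, nfft=2 * nperseg)
peak_whole = f_w[np.argmax(Pxx)]                 # one number for the whole recording

print(peak_per_slice.min(), peak_per_slice.max(), peak_whole)
```

The per-slice peaks span both rhythms, but the whole-recording argmax lands on only one of them, which is why a scatter of single peak values per region cannot reveal frequency wandering.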

Per-Region \(\omega\) Drift for the Best Seed

The gradient stage turns \(\omega\) into a per-region vector (the per-seed drift summary is in the overview table above). For the best seed, the plot below shows how far each region’s fitted frequency moved from its target-gradient initialization.
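The plotting code converts the fitted \(\omega\) to Hz with \(\omega \cdot 1000 / 2\pi\), which implies that \(\omega\) is stored in rad/ms (an inference from that factor, not stated elsewhere in this section). Under that assumption, 1 Hz corresponds to \(2\pi/1000 \approx 0.00628\) rad/ms, so a drift that looks negligible in raw \(\omega\) units can still be a visible shift in Hz. A minimal sketch of the two conversions, with hypothetical helper names:

```python
import numpy as np

# Assumption: omega is stored in rad/ms, inferred from the omega * 1000 / (2*pi)
# conversion in the plotting code. Adjust MS_PER_S if the model uses another time unit.
MS_PER_S = 1000.0

def hz_to_omega(f_hz):
    """Frequency in Hz -> angular frequency in rad/ms (assumed model unit)."""
    return 2.0 * np.pi * np.asarray(f_hz, dtype=float) / MS_PER_S

def omega_to_hz(omega):
    """Angular frequency in rad/ms -> frequency in Hz (inverse of hz_to_omega)."""
    return np.asarray(omega, dtype=float) * MS_PER_S / (2.0 * np.pi)

# Usage: initialize omega from hypothetical target peak frequencies and convert back
f_target = np.array([6.0, 8.5, 10.0])        # hypothetical per-region targets [Hz]
omega_init = hz_to_omega(f_target)           # 10 Hz -> 2*pi/100 rad/ms
print(omega_init, omega_to_hz(omega_init))
```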

omega_init_hz  = peak_freqs_target
omega_final_hz = np.asarray(best_row["omega_final"]) * 1000.0 / (2 * np.pi)

order_f = np.argsort(omega_init_hz)
xs = np.arange(n_nodes)

fig, ax = plt.subplots(figsize=(9, 3.5))
ax.scatter(xs, omega_init_hz[order_f], facecolors="none", edgecolors="black",
           s=45, linewidths=1.0, label="target")
ax.scatter(xs, omega_final_hz[order_f], c="tab:blue", s=45, label="fitted")
ax.set_xlabel("Region (sorted by target frequency)")
ax.set_ylabel(r"Natural frequency $\omega / 2\pi$ [Hz]")
ax.legend(fontsize=8)
ax.grid(alpha=0.3)
plt.tight_layout()
plt.show()
Figure 5: Per-region \(\omega\) drift for the best seed. Natural frequency \(\omega/2\pi\) at initialization, i.e. the target peak frequency (open circles), vs. after gradient optimization (filled), in Hz. Regions are sorted by target frequency.
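Figure 5 is read by eye; a few summary numbers make the drift comparable across seeds. The sketch below uses a hypothetical helper `drift_summary` and made-up values; in the notebook its inputs would be `omega_init_hz` and `omega_final_hz` from the plotting code. A rank correlation of 1 means the gradient stage shifted frequencies without reordering regions, so the spatial gradient survived even where absolute values moved.

```python
import numpy as np
from scipy.stats import spearmanr

def drift_summary(f_init_hz, f_fit_hz):
    """Summarize how far fitted natural frequencies moved from their initialization."""
    f_init_hz = np.asarray(f_init_hz, dtype=float)
    f_fit_hz = np.asarray(f_fit_hz, dtype=float)
    d = f_fit_hz - f_init_hz
    rho = spearmanr(f_init_hz, f_fit_hz)[0]    # 1.0 means the regional ordering is unchanged
    return {"mean_abs_hz": float(np.mean(np.abs(d))),
            "max_abs_hz": float(np.max(np.abs(d))),
            "rank_corr": float(rho)}

# Usage with hypothetical values (Hz)
f_init = np.array([6.0, 7.0, 8.0, 9.0, 10.0])
f_fit = np.array([6.2, 6.9, 8.4, 9.1, 10.0])
print(drift_summary(f_init, f_fit))
```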

Summary

Seeds from across the Pareto front, high-FC and high-gradient extremes alike, converged into the same intermediate region of parameter space after the parallel gradient stage. That convergence is the main result: it suggests that the combined loss has one dominant basin, located where neither objective alone would land.
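The convergence claim can be checked numerically by comparing how spread out the seeds are in parameter space before and after the gradient stage. Below is a minimal sketch with a hypothetical helper `spread_ratio` and toy arrays; the final values could come from the `G_final` / `a_final` entries of each seed's results row, while the source of the initial values is not shown in this section. A ratio well below 1 is consistent with a shared basin but does not prove it, since seeds could also stall on a common plateau.

```python
import numpy as np

def spread_ratio(init, final):
    """Per-parameter ratio of across-seed std after vs. before the gradient stage.

    init, final: arrays of shape (n_seeds, n_params). Ratios well below 1
    mean the seeds moved toward a common region; the ratio is unit-free.
    """
    init = np.asarray(init, dtype=float)
    final = np.asarray(final, dtype=float)
    return final.std(axis=0) / init.std(axis=0)

# Hypothetical (G, a) values for four seeds spread along a Pareto front
init_params = np.array([[0.1, -0.5], [0.5, -0.1], [1.0, 0.0], [2.0, 0.2]])
final_params = np.array([[0.80, -0.05], [0.85, -0.04], [0.90, -0.05], [0.85, -0.06]])
r = spread_ratio(init_params, final_params)
print(dict(zip(["G", "a"], np.round(r, 3))))
```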

The fit is plausible, not exact. A single Hopf node carries one amplitude and one frequency, with no laminar structure and no E/I balance, so residual mismatch in both FC and the gradient is expected. The combined loss holds the model in that intermediate regime instead of letting either metric run to its degenerate optimum.
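The regularization argument can be seen in a one-dimensional toy picture along the coupling axis \(G\). The loss shapes below are hypothetical stand-ins, not the notebook's losses: the FC term only improves with stronger coupling, the gradient term only improves with weaker coupling, and their sum has an interior minimum. For this toy the minimum sits at \(G^* = 2\ln(2/\lambda)\), so the relative weight \(\lambda\) decides where in the intermediate regime the fit lands.

```python
import numpy as np

# Toy, hypothetical loss shapes along the coupling axis G (not the notebook's losses)
G = np.linspace(0.0, 5.0, 501)
loss_fc = np.exp(-G)                 # lowest at large G: the oversynchronization corner
loss_freq = 1.0 - np.exp(-0.5 * G)   # lowest at G = 0: the decoupled corner

lam = 1.0                            # hypothetical weight on the gradient term
total = loss_fc + lam * loss_freq

G_fc_only = G[np.argmin(loss_fc)]    # runs to the upper edge of the grid
G_freq_only = G[np.argmin(loss_freq)]  # runs to zero coupling
G_joint = G[np.argmin(total)]        # interior minimum, near 2 * ln(2 / lam)
print(G_fc_only, G_freq_only, G_joint)
```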