Stimulation with Bayesian Inference

Recovering Stimulus and Excitability from a Degenerate Forward Model

Introduction

A pulse stimulus is delivered to a Generic2dOscillator and we record a noisy time series. From that one recording, we want to recover two things: how strong the stimulus was, and how excitable the underlying system is. The problem is that these two parameters trade off. A weaker stimulus into a more excitable system produces almost the same trajectory as a stronger stimulus into a less excitable one.

Gradient descent picks an arbitrary point on this ridge of equally good fits and reports it as the answer. Bayesian inference reports the whole ridge, and prior knowledge about either parameter (a device log, a PET scan) shifts the posterior onto a specific segment of the ridge, with the prior’s strength controlling how tight that segment is.
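
To make the trade-off concrete, here is a minimal sketch with a toy stand-in for the simulator. `toy_response` and its `gain` are hypothetical and not part of tvboptim. The toy depends on the parameters only through `amplitude + gain * excitability`, so its degeneracy is exact, whereas the Generic2dOscillator is only approximately degenerate.

```python
import math


def toy_response(amplitude, excitability, n_steps=50, tau=10.0, gain=2.0):
    """Hypothetical stand-in for the simulator: a decaying response whose
    size depends only on amplitude + gain * excitability."""
    drive = amplitude + gain * excitability
    return [drive * math.exp(-i / tau) for i in range(n_steps)]


def mse(xs, ys):
    """Mean squared difference between two equally long traces."""
    return sum((x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


truth = toy_response(0.4, 0.1)            # ground truth used in this notebook
weak_excitable = toy_response(0.2, 0.2)   # weaker stimulus, more excitable system
strong_quiet = toy_response(0.6, 0.0)     # stronger stimulus, less excitable system
off_ridge = toy_response(0.6, 0.2)        # both raised, nothing compensates

print("weak + excitable :", mse(truth, weak_excitable))
print("strong + quiet   :", mse(truth, strong_quiet))
print("off the ridge    :", mse(truth, off_ridge))
```

The first two errors are at floating-point noise level while the third is clearly non-zero: the data can tell points on the ridge from points off it, but not from each other.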

We run the example at a single node so the documentation renders in CI; the same construction applies node-by-node to a network.

Warning: What this notebook does not claim
  • Single-region by construction. The forward model, prior, and posterior all live at one node. The extension to networks is mechanical (per-node parameters as DataAxis), but rendering a multi-region MCMC fit would dominate CI time.
  • The likelihood width is set deliberately wider than the data noise (LIKELIHOOD_SIGMA = 0.2 vs OBS_NOISE_SIGMA = 0.1). This is a didactic choice that keeps the likelihood flat along the ridge so the prior choice is visibly responsible for the difference between scenarios. With matched σ the posterior collapses harder onto the truth and the steering effect is harder to see. The section "On the likelihood width" unpacks this.
  • Priors here stand in for evidence sources. We do not pull from a real PET map. Replacing dist.Normal(PRIOR_AMP_MEAN, amp_std) with dist.Normal(mu_from_PET, sigma_from_PET) is a one-line change, but this notebook does not work through an empirical example.
  • High-dim inference is mostly a hardware question. Reverse-mode AD keeps the gradient cost almost flat in parameter count, so HMC on a thousand parameters is feasible given a GPU and enough run time. The real limit is posterior geometry (multimodality, conditioning), not raw dimension. The final section discusses what scales well and when to swap method.

The notebook proceeds in four steps:

  1. Generate a noisy observation from a known ground truth.
  2. Scan the MSE landscape and show that many parameter pairs explain the data.
  3. Sample three posteriors with different priors on the same observation.
  4. Compare those posteriors against multi-start gradient descent on the same priors.
Tip: What you’ll learn
  • Concepts: what parameter degeneracy looks like, why a point estimate hides it, and how a prior turns “any answer on the ridge” into “the answer consistent with what you already know”.
  • TVB-Optim idioms: wiring a tvboptim forward simulation into a numpyro model, using GridAxis for landscape scans, and using DataAxis to batch posterior-predictive draws.
  • Workflow: what changes when you swap an optimizer for a sampler, and how to read the result.
Environment Setup and Imports
# Set up environment — XLA_FLAGS must be set BEFORE importing jax. Here we expose
# 25 virtual devices so ParallelExecution can spread landscape scans, posterior
# predictive draws, and multi-start optimisations across many devices at once.
import os

N_DEVICES = 25
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={N_DEVICES}"

import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpyro
import numpyro.distributions as dist
import optax
from numpyro.infer import MCMC, NUTS
from numpyro.infer.util import log_density
from scipy.stats import gaussian_kde, iqr

from tvboptim.execution import ParallelExecution
from tvboptim.experimental.network_dynamics import Network, prepare, solve as nd_solve
from tvboptim.experimental.network_dynamics.dynamics.tvb import Generic2dOscillator
from tvboptim.experimental.network_dynamics.coupling import LinearCoupling
from tvboptim.experimental.network_dynamics.graph import DenseGraph
from tvboptim.experimental.network_dynamics.external_input import PulseInput
from tvboptim.experimental.network_dynamics.solvers import Heun
from tvboptim.types import DataAxis, GridAxis, Space, Parameter, collect_parameters
from tvboptim.optim import OptaxOptimizer
from tvboptim.utils import set_cache_path, cache

# Cache directory for expensive MCMC and optimisation runs
set_cache_path("./bayes_stim")

Shared Settings and Priors

The three scenarios use identical prior means and differ only in their standard deviations. A tight prior on one parameter is how we tell the model that we already have evidence about it:

  • A: no hypothesis. Wide priors on both parameters.
  • B (H1, higher excitability): tight prior on amplitude. The stimulator log is trusted, so the data has to be explained via excitability.
  • C (H2, stronger stimulus): tight prior on excitability. PET or gene-expression evidence pins excitability down, so the data has to be explained via amplitude.
T1 = 150.0
DT = 0.2
ONSET = 10.0
DURATION = 1.0

TRUE_AMPLITUDE = 0.4
TRUE_EXCITABILITY = 0.1
OBS_NOISE_SIGMA = 0.1    # data generation noise
LIKELIHOOD_SIGMA = 0.2   # fixed in likelihood — intentionally flat along ridge so priors steer

PRIOR_AMP_MEAN = 0.2
PRIOR_EXC_MEAN = 0.0

# Scenario priors — all share the same means; each hypothesis constrains one
# parameter with a tight σ:
#   A: no hypothesis.            Wide priors on both.
#   B: H1, higher excitability.  Tight amp prior → excitability has to explain.
#   C: H2, stronger stimulus.    Tight exc prior → amplitude has to explain.
PRIOR_AMP_STD = {"A": 0.20, "B": 0.10, "C": 0.20}
PRIOR_EXC_STD = {"A": 0.10, "B": 0.10, "C": 0.05}

SCENARIO_LABELS = {
    "A": "No hypothesis\n(wide priors on both)",
    "B": f"H1: Higher excitability\namp ~ N({PRIOR_AMP_MEAN}, {PRIOR_AMP_STD['B']})",
    "C": f"H2: Stronger stimulus\nexc ~ N({PRIOR_EXC_MEAN}, {PRIOR_EXC_STD['C']})",
}

DYNAMICS_PARAMS = dict(a=-1.5, b=-15.0, c=0.0, d=0.015, e=3.0, f=1.0, tau=4.0)
weights = jnp.zeros((1, 1))

SCENARIO_KEYS = ["A", "B", "C"]
COLORS = ["tab:blue", "tab:green", "tab:orange"]

Stimulation in TVB-Optim: quick recap

The stimulus in this notebook is a rectangular pulse delivered to a single-node Generic2dOscillator. In TVB-Optim, stimuli live in the external_input slot of a Network and are routed by name. The dynamics class declares which names it accepts via its EXTERNAL_INPUTS attribute; for Generic2dOscillator the name is 'stimulus', and the value gets added into the equation for the fast variable V.

The pulse itself is a PulseInput parametrized by three numbers:

  • onset — when the pulse starts (ms).
  • duration — how long it stays on.
  • amplitude — the constant value driving V during the on-window.

PulseInput is a parametric input: a function of time that the solver evaluates at each step. The alternative is DataInput, which interpolates between recorded samples; use it when you have an experimental stimulator trace rather than an idealized waveform. Any parameter of an external input can be wrapped in Parameter(...) to mark it as optimizable, which is exactly what we do later when comparing Bayesian inference against gradient descent.
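
As a rough picture of what "a function of time that the solver evaluates at each step" means, the sketch below writes a rectangular pulse in plain Python. `rectangular_pulse` is a hypothetical illustration, not the PulseInput implementation, and the half-open on-window `[onset, onset + duration)` is an assumption: the source does not say how PulseInput treats the edges.

```python
def rectangular_pulse(t, onset=10.0, duration=1.0, amplitude=0.4):
    """Value of an idealized rectangular pulse at time t (ms).

    Assumes the on-window is half-open: [onset, onset + duration).
    """
    return amplitude if onset <= t < onset + duration else 0.0


# Evaluate on the notebook's time step (DT = 0.2 ms) over the first 100 ms.
dt = 0.2
times = [round(i * dt, 10) for i in range(int(100 / dt) + 1)]
trace = [rectangular_pulse(t) for t in times]

# Number of solver steps during which the stimulus is switched on.
n_on = sum(1 for value in trace if value != 0.0)
print("steps inside the on-window:", n_on)
```

With a 1 ms pulse and a 0.2 ms step only a handful of solver steps see a non-zero input, which is why the response is shaped mostly by the dynamics after the pulse rather than by the pulse itself.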

In code, the wiring lives inside build_network (in the collapsed helpers block below):

Network(
    dynamics=Generic2dOscillator(..., I=excitability, ...),
    coupling={"instant": LinearCoupling(...)},
    graph=DenseGraph(weights),
    external_input={"stimulus": PulseInput(
        onset=ONSET, duration=DURATION, amplitude=amplitude,
    )},
)

The name "stimulus" on the left of the dict has to match the name in EXTERNAL_INPUTS. The two free parameters of this notebook are amplitude (a property of the input) and I (a property of the dynamics, the excitability). They sit in different parts of the config, but the Bayesian model treats them symmetrically.

External inputs carry a .plot() method, which is the quickest way to confirm the pulse looks like what you wrote down:

PulseInput(onset=ONSET, duration=DURATION, amplitude=TRUE_AMPLITUDE).plot(t0=0, t1=100)
Figure 1: The pulse stimulus at the true parameters. A 1 ms rectangular pulse of amplitude TRUE_AMPLITUDE = 0.4 starting at ONSET = 10 ms. This is the time course added to the V equation of the Generic2dOscillator during the on-window.

On the likelihood width

The data-generating noise is OBS_NOISE_SIGMA = 0.1 but the likelihood is scored with LIKELIHOOD_SIGMA = 0.2. Why the mismatch?

A matched σ produces a posterior already tight enough around the truth that the prior choice barely moves it, hiding the demonstration. Inflating the likelihood σ keeps it flat along the degeneracy ridge, so the prior is what decides where the posterior concentrates. The trade-off is explicit: posteriors are deliberately wider than a maximum-likelihood analysis would produce. Rerun with LIKELIHOOD_SIGMA = OBS_NOISE_SIGMA and the posterior along the ridge shortens but does not collapse to a point: the prior still steers, just over a narrower region.
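
The size of the flattening can be worked out directly. For a Gaussian likelihood, the log-likelihood gap between two candidate predictions is the difference of their squared errors divided by 2σ², so doubling σ from 0.1 to 0.2 divides every gap by four. The sketch below checks this with made-up residuals; the numbers in `on_ridge` and `off_ridge` are hypothetical.

```python
import math


def gaussian_loglik(residuals, sigma):
    """Log-likelihood of independent Gaussian residuals with fixed scale sigma."""
    n = len(residuals)
    sse = sum(r * r for r in residuals)
    return -0.5 * sse / sigma**2 - n * math.log(sigma * math.sqrt(2.0 * math.pi))


# Hypothetical residuals (observation minus prediction) for two candidates.
on_ridge = [0.05, -0.08, 0.02, 0.10, -0.04]
off_ridge = [0.25, -0.30, 0.20, 0.35, -0.15]

gap_matched = gaussian_loglik(on_ridge, 0.1) - gaussian_loglik(off_ridge, 0.1)
gap_inflated = gaussian_loglik(on_ridge, 0.2) - gaussian_loglik(off_ridge, 0.2)

print("gap at sigma = 0.1:", gap_matched)
print("gap at sigma = 0.2:", gap_inflated)
print("ratio:", gap_matched / gap_inflated)
```

The prior terms do not depend on σ, so shrinking the likelihood gaps by a factor of four gives the priors correspondingly more say in where the posterior sits.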

The Bayesian Model

This is what “transitioning from point estimates to full posteriors” looks like in code: the forward model is unchanged, the loss is replaced by a likelihood, and a sampler explores the joint of priors × likelihood. make_model returns a numpyro model that describes the joint distribution of parameters and data. Three things happen inside the inner model function:

  1. Priors. numpyro.sample("amplitude", dist.Normal(...)) draws an amplitude from the prior. It also tells numpyro: this is a latent variable named “amplitude”, track its log-density. Same for excitability. The standard deviations amp_std and exc_std are scenario-specific and baked in from the closure.
  2. Forward simulation. The sampled parameters are written into the simulation config and the forward model produces a predicted trajectory v_pred.
  3. Likelihood. The final numpyro.sample("obs", ..., obs=v_obs) declares that the observation v_obs is a noisy version of v_pred with fixed noise scale LIKELIHOOD_SIGMA. The obs= keyword is what turns this from “sample obs” into “score obs against the model”.
def make_model(scenario_key):
    """Return a numpyro model with scenario-specific prior widths baked in as closure constants."""
    amp_std = float(PRIOR_AMP_STD[scenario_key])
    exc_std = float(PRIOR_EXC_STD[scenario_key])

    def model(v_obs, solve_fn, config, obs_idx):
        # 1. Priors on the two latent parameters
        amplitude    = numpyro.sample("amplitude",    dist.Normal(PRIOR_AMP_MEAN, amp_std))
        excitability = numpyro.sample("excitability", dist.Normal(PRIOR_EXC_MEAN, exc_std))

        # 2. Forward simulation with the sampled parameters
        config.external.stimulus.amplitude = amplitude
        config.dynamics.I = excitability
        v_pred = solve_fn(config).ys[obs_idx, 0, 0]

        # 3. Likelihood: observation is v_pred + Gaussian noise
        numpyro.sample("obs", dist.Normal(v_pred, LIKELIHOOD_SIGMA), obs=v_obs)

    return model

NUTS then samples the joint posterior over amplitude and excitability by combining (1), (2), and (3). The sampler never sees the priors as Python objects: it only sees the resulting log-density. That is why amp_std and exc_std are converted to Python floats before being captured in the closure. Passing live dist.Normal objects in from the outside would break JIT tracing and silently drop the priors.
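
What the sampler sees can be written out by hand. The sketch below adds the two prior terms and the likelihood term into one unnormalized log-density, using a toy forward model in place of `solve_fn`. `toy_forward`, `log_joint`, and the `gain` constant are hypothetical names for illustration, and the observation is noise-free to keep the example short.

```python
import math


def normal_logpdf(x, mean, std):
    """Log-density of a univariate normal distribution."""
    z = (x - mean) / std
    return -0.5 * z * z - math.log(std * math.sqrt(2.0 * math.pi))


def toy_forward(amplitude, excitability, n_obs=50, tau=10.0, gain=2.0):
    """Hypothetical stand-in for solve_fn: depends on amplitude + gain * excitability."""
    drive = amplitude + gain * excitability
    return [drive * math.exp(-i / tau) for i in range(n_obs)]


def log_joint(amplitude, excitability, v_obs, amp_std, exc_std,
              amp_mean=0.2, exc_mean=0.0, likelihood_sigma=0.2):
    """Unnormalized log posterior: log prior + log likelihood."""
    log_prior = (normal_logpdf(amplitude, amp_mean, amp_std)
                 + normal_logpdf(excitability, exc_mean, exc_std))
    v_pred = toy_forward(amplitude, excitability)
    log_lik = sum(normal_logpdf(obs, pred, likelihood_sigma)
                  for obs, pred in zip(v_obs, v_pred))
    return log_prior + log_lik


v_obs = toy_forward(0.4, 0.1)  # noise-free toy observation at the true parameters

# Two points on the toy ridge: identical likelihood, different prior support.
# Scenario B: tight amplitude prior (0.10).  Scenario C: tight excitability prior (0.05).
gap_b = log_joint(0.6, 0.0, v_obs, 0.10, 0.10) - log_joint(0.2, 0.2, v_obs, 0.10, 0.10)
gap_c = log_joint(0.6, 0.0, v_obs, 0.20, 0.05) - log_joint(0.2, 0.2, v_obs, 0.20, 0.05)
print("B prefers (0.6, 0.0) over (0.2, 0.2) by", gap_b)
print("C prefers (0.6, 0.0) over (0.2, 0.2) by", gap_c)
```

In the toy, scenario B favours the low-amplitude point and scenario C the low-excitability point, purely through the prior terms. This is the steering the rest of the notebook demonstrates with the real forward model.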

The model is decoupled from the sampler. The same model accepts three inference paths with no model change — pick by use case:

# 1. NUTS — full posterior (this notebook's default; see run_mcmc in the helpers)
mcmc = MCMC(NUTS(model, dense_mass=True), num_warmup=500, num_samples=2000)
mcmc.run(rng, v_obs, sf, cfg, obs_idx)
posterior_samples = mcmc.get_samples()

# 2. SVI with an auto-guide — parametric posterior, much cheaper, less faithful in the tails
from numpyro.infer import SVI, Trace_ELBO, autoguide
guide = autoguide.AutoMultivariateNormal(model)
svi_result = SVI(model, guide, optax.adam(1e-2), Trace_ELBO()).run(
    rng, 2000, v_obs, sf, cfg, obs_idx,
)
svi_samples = guide.sample_posterior(rng, svi_result.params, sample_shape=(2000,))

# 3. MAP — same machinery, point estimate (the Bayesian relative of gradient descent)
map_run = SVI(model, autoguide.AutoDelta(model), optax.adam(1e-2), Trace_ELBO()).run(
    rng, 1000, v_obs, sf, cfg, obs_idx,
)

NUTS gives a faithful posterior but pays for it in gradient evaluations. SVI fits a parametric family — fast, but the family choice is a modelling decision and tail behaviour can be off. MAP recovers a point estimate within the same framework, and is conceptually what the multi-start optimisation later in this notebook approximates with a flat prior.

The helper functions below are plumbing: the network factory, the MSE loss for the optimization comparison, the MCMC runner with NUTS settings, and a small landscape-drawing helper used by several figures.

def build_network(amplitude, excitability):
    """Deterministic forward model network."""
    return Network(
        dynamics=Generic2dOscillator(**DYNAMICS_PARAMS, I=excitability, VARIABLES_OF_INTEREST=("V",)),
        coupling={"instant": LinearCoupling(incoming_states="V", G=0.0)},
        graph=DenseGraph(weights),
        external_input={"stimulus": PulseInput(onset=ONSET, duration=DURATION, amplitude=amplitude)},
    )


def make_loss(solve_fn):
    """MSE loss against observed data, closed over solve_fn."""
    def loss(config):
        return jnp.mean((solve_fn(config).ys[obs_idx, 0, 0] - v_obs) ** 2)
    return loss


def run_mcmc(model_fn, seed, label, num_warmup=500, num_samples=2000, num_chains=1):
    print(f"\n{'='*60}\n{label}\n{'='*60}")
    net = build_network(TRUE_AMPLITUDE, TRUE_EXCITABILITY)
    sf, cfg = prepare(net, Heun(), t0=0.0, t1=T1, dt=DT)
    nuts = NUTS(
        model_fn,
        max_tree_depth=10,
        dense_mass=True,        # learns ridge correlation → explores along it
        target_accept_prob=0.8,
    )
    mcmc = MCMC(nuts, num_warmup=num_warmup, num_samples=num_samples, num_chains=num_chains)
    mcmc.run(jax.random.key(seed), v_obs, sf, cfg, obs_idx)
    mcmc.print_summary()
    return mcmc.get_samples(group_by_chain=False)


def _draw_landscape(ax, vmax=None):
    """Draw MSE heatmap + 10th-percentile contour + ground truth marker; return pcm."""
    pcm = ax.pcolormesh(amp_vals, excit_vals, mse_grid,
                        cmap="cividis_r", vmax=vmax or vmax_clip)
    ax.contour(amp_vals, excit_vals, mse_grid,
               levels=[float(jnp.percentile(mse_grid, 10))],
               colors="white", linewidths=1.0, linestyles="--")
    ax.scatter(TRUE_AMPLITUDE, TRUE_EXCITABILITY,
               color="k", marker="*", s=150, zorder=6, label="ground truth")
    ax.set_xlabel("Stimulus amplitude")
    ax.set_ylabel("Excitability (I)")
    # Clamp axes to the scanned landscape extent so later scatter() calls
    # (posterior samples, optimisation endpoints) don't trigger autoscale.
    ax.set_xlim(float(amp_vals[0]),   float(amp_vals[-1]))
    ax.set_ylim(float(excit_vals[0]), float(excit_vals[-1]))
    return pcm

Generating the Observation

Simulate the forward model at the true parameters, subsample every 15 steps, add Gaussian noise. The noiseless trajectory is kept around as the reference for MSE in the landscape scan that follows.

# Noiseless signal as base ensures MLE at true params. Ridge points produce near-identical
# noiseless signals, so the likelihood is flat along the ridge and priors steer freely.
network_true = build_network(TRUE_AMPLITUDE, TRUE_EXCITABILITY)
solve_true, config_true = prepare(network_true, Heun(), t0=0.0, t1=T1, dt=DT)
sol_true = solve_true(config_true)   # run the simulation once and reuse the result
v_noiseless = sol_true.ys[:, 0, 0]
ts = sol_true.ts

obs_idx = jnp.arange(0, len(ts), 15)
ts_obs = ts[obs_idx]
v_obs = v_noiseless[obs_idx] + OBS_NOISE_SIGMA * jax.random.normal(
    jax.random.key(42), (len(obs_idx),)
)
Figure 2: Observed data and underlying signal. Blue: noiseless ground-truth trajectory at amp=0.4, I=0.1. Black points: subsampled observations with additive Gaussian noise. Red band: the stimulus pulse window.

Mapping the Loss Landscape

Before any inference, scan the 2D parameter space and compute MSE between the deterministic simulation and the noiseless target at each grid point. The result is the degeneracy landscape: the set of (amplitude, excitability) pairs the data cannot distinguish.

N_AMP = N_DEVICES
N_EXCIT = N_DEVICES

net_scan = build_network(TRUE_AMPLITUDE, TRUE_EXCITABILITY)
sf_scan, cfg_scan = prepare(net_scan, Heun(), t0=0.0, t1=T1, dt=DT)

cfg_scan.external.stimulus.amplitude = GridAxis(0.0, 1.0, N_AMP)
cfg_scan.dynamics.I = GridAxis(-0.1, 0.5, N_EXCIT)
space = Space(cfg_scan, mode="product")

par = ParallelExecution(
    model=lambda cfg: sf_scan(cfg).ys[:, 0, 0],
    space=space, n_pmap=N_DEVICES, n_vmap=N_DEVICES,
)
scan_results = par.run()
print("Landscape scan complete.")

df_scan = scan_results.to_dataframe()
amp_vals   = jnp.array(sorted(df_scan["external.stimulus.amplitude"].unique()))
excit_vals = jnp.array(sorted(df_scan["dynamics.I"].unique()))

mse_grid = jnp.zeros((len(excit_vals), len(amp_vals)))
for row_idx, row in df_scan.iterrows():
    mse   = float(jnp.mean((scan_results[row_idx] - v_noiseless) ** 2))
    i_amp = int(jnp.argmin(jnp.abs(amp_vals   - row["external.stimulus.amplitude"])))
    i_ex  = int(jnp.argmin(jnp.abs(excit_vals - row["dynamics.I"])))
    mse_grid = mse_grid.at[i_ex, i_amp].set(mse)

vmax_clip = float(jnp.percentile(mse_grid, 75))

Pick the eight lowest-MSE grid points (excluding a small neighborhood of the ground truth) and overlay their simulated trajectories. They are visually nearly identical despite coming from qualitatively different mechanistic explanations.

Figure 3: The degeneracy ridge. Top: MSE landscape with the eight lowest-loss off-truth grid points marked. Bottom: their simulated trajectories overlaid on the ground truth. They are visually almost identical despite different (amp, I) combinations. This is the ambiguity any inference method has to deal with.
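
The selection step behind this figure is not shown in the notebook. A minimal sketch of one way to do it follows: rank all grid points by MSE and skip those inside a small disc around the ground truth. The toy MSE surface, the `exclusion_radius` value, and the assumption that the grid endpoints are inclusive are all hypothetical; the real notebook would rank `mse_grid` instead.

```python
import math


def toy_mse(amplitude, excitability, true_amp=0.4, true_exc=0.1, gain=2.0):
    """Hypothetical MSE surface whose minimum is a straight ridge."""
    return (amplitude + gain * excitability - (true_amp + gain * true_exc)) ** 2


# Same ranges and resolution as the landscape scan, assuming inclusive endpoints.
amp_vals = [i / 24 for i in range(25)]                # 0.0 .. 1.0
exc_vals = [-0.1 + 0.6 * i / 24 for i in range(25)]   # -0.1 .. 0.5


def lowest_off_truth(k=8, exclusion_radius=0.1, true_amp=0.4, true_exc=0.1):
    """Return the k lowest-MSE grid points outside a disc around the truth."""
    candidates = []
    for amp in amp_vals:
        for exc in exc_vals:
            if math.hypot(amp - true_amp, exc - true_exc) <= exclusion_radius:
                continue  # too close to the ground truth to be informative
            candidates.append((toy_mse(amp, exc), amp, exc))
    candidates.sort()  # ascending MSE
    return candidates[:k]


picked = lowest_off_truth()
for mse_value, amp, exc in picked:
    print(f"amp={amp:.3f}  I={exc:.3f}  mse={mse_value:.2e}")
```

Excluding the neighbourhood of the truth matters: without it the lowest-loss points would all crowd around the true parameters and the figure would say nothing about the rest of the ridge.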

Sanity-Checking the Priors

Long MCMC runs hide silent bugs. Before launching one, evaluate the log joint density at two points under scenario C (tight excitability prior). A point with high |I| should score much worse than one with low |I| if the tight prior is actually active. If the two scores come out similar, the priors were dropped during tracing and the whole experiment is wrong.

_net_chk = build_network(TRUE_AMPLITUDE, TRUE_EXCITABILITY)
_sf_chk, _cfg_chk = prepare(_net_chk, Heun(), t0=0.0, t1=T1, dt=DT)
_args = (v_obs, _sf_chk, _cfg_chk, obs_idx)

# model_C has tight excitability prior → high-|I| points should score worse
_model_C = make_model("C")
lp_low,  _ = log_density(_model_C, _args, {}, {"amplitude": 0.5, "excitability": 0.0})
lp_high, _ = log_density(_model_C, _args, {}, {"amplitude": 0.1, "excitability": 0.3})
print(f"model_C  log p(amp=0.50, I=0.0):  {lp_low:.1f}")
print(f"model_C  log p(amp=0.10, I=0.3):  {lp_high:.1f}")
print(f"Difference: {lp_low - lp_high:.1f}  (should be >> 0 if tight exc prior is active)")
model_C  log p(amp=0.50, I=0.0):  30.6
model_C  log p(amp=0.10, I=0.3):  -123.3
Difference: 153.8  (should be >> 0 if tight exc prior is active)
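
It helps to know how much of that gap the prior alone accounts for. The sketch below evaluates only the two prior terms of scenario C, Normal(0.2, 0.20) on amplitude and Normal(0.0, 0.05) on excitability, at the same two points. `log_prior_c` is a hypothetical helper written for this check; it does not call numpyro.

```python
import math


def normal_logpdf(x, mean, std):
    """Log-density of a univariate normal distribution."""
    z = (x - mean) / std
    return -0.5 * z * z - math.log(std * math.sqrt(2.0 * math.pi))


def log_prior_c(amplitude, excitability):
    """Prior log-density under scenario C (tight excitability prior)."""
    return (normal_logpdf(amplitude, 0.2, 0.20)
            + normal_logpdf(excitability, 0.0, 0.05))


# Same two points as in the log_density check.
prior_gap = log_prior_c(0.5, 0.0) - log_prior_c(0.1, 0.3)
print(f"prior-only difference: {prior_gap:.1f}")
```

The prior contributes 17.0 nats of the 153.8 printed by the log_density check, so the remainder comes from the likelihood. A gap well above the prior-only value is therefore expected. The check would signal a problem if the total dropped by roughly that amount, which is what dropped priors would look like.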

Running MCMC Under Three Hypotheses

Sample three posteriors with NUTS. dense_mass=True matters here: the posterior is a thin diagonal ridge, and a learned full mass matrix lets the sampler step along the ridge instead of zig-zagging across it.

MCMC_KWARGS = dict(num_warmup=500, num_samples=4000, num_chains=1)

# @cache stores the function's return value on disk. On rerun the cached samples
# are loaded instead of resampling — set redo=True to force re-running the chain
# (e.g. after changing the priors or likelihood).
@cache("mcmc_samples", redo=False)
def run_all_mcmc():
    return [
        run_mcmc(make_model("A"), seed=0, label="A: No hypothesis (uninformative priors)", **MCMC_KWARGS),
        run_mcmc(make_model("B"), seed=1, label="H1: Higher excitability (tight amplitude prior, e.g. device logs, protocol)", **MCMC_KWARGS),
        run_mcmc(make_model("C"), seed=2, label="H2: Stronger stimulus (tight excitability prior, e.g. PET, gene expression)", **MCMC_KWARGS),
    ]

all_samples = run_all_mcmc()
Figure 4: Joint posteriors overlaid on the degeneracy landscape. Same observation, three hypotheses. Wide priors (A) let samples sprawl along the entire ridge; the tight-amplitude prior (B) confines them to a high-excitability segment; the tight-excitability prior (C) confines them to a high-amplitude segment. The data has not changed across panels, only the slice of the ridge the model is allowed to consider.
Figure 5: Marginal posteriors versus priors. Dashed grey: prior density. Solid coloured: posterior KDE. Black: ground truth. Top row: amplitude; bottom row: excitability. A wide prior produces a data-driven posterior. A tight prior produces a posterior that tracks the prior, and the other parameter shifts to absorb the discrepancy.
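
The "other parameter shifts to absorb the discrepancy" behaviour has a closed form if the ridge is approximated by a straight line. Assume the data pin down only the combination amplitude + k · excitability = s, and put independent Gaussian priors on both parameters. The most probable point on that line moves each parameter away from its prior mean in proportion to its prior variance. The sketch below uses the notebook's prior means and widths, but the slope `gain = 2.0` and the value `drive = 0.6` are hypothetical choices, not fitted to the Generic2dOscillator.

```python
def ridge_mode(prior_amp_std, prior_exc_std,
               prior_amp_mean=0.2, prior_exc_mean=0.0,
               drive=0.6, gain=2.0):
    """Most probable (amplitude, excitability) on the line
    amplitude + gain * excitability = drive, under independent Gaussian priors."""
    residual = drive - prior_amp_mean - gain * prior_exc_mean
    var_amp = prior_amp_std**2
    var_exc = prior_exc_std**2
    total = var_amp + gain**2 * var_exc
    amplitude = prior_amp_mean + residual * var_amp / total
    excitability = prior_exc_mean + residual * gain * var_exc / total
    return amplitude, excitability


# Prior widths copied from PRIOR_AMP_STD and PRIOR_EXC_STD.
scenarios = {"A": (0.20, 0.10), "B": (0.10, 0.10), "C": (0.20, 0.05)}
modes = {key: ridge_mode(amp_std, exc_std) for key, (amp_std, exc_std) in scenarios.items()}
for key, (amp, exc) in modes.items():
    print(f"{key}: amplitude={amp:.2f}  excitability={exc:.2f}")
```

In this linear picture, tightening the amplitude prior (B) holds amplitude near its prior mean and pushes excitability up, while tightening the excitability prior (C) does the reverse. The direction of the shift matches the figures above; the exact numbers depend on the assumed slope.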

Posterior Predictive

Push the posterior back through the forward model: draw N_PP_DRAWS parameter pairs per scenario, simulate each one, overlay the trajectories. The spread of the resulting ensemble is the model’s uncertainty about future observations under that hypothesis, expressed in the same units as the data.

# Semi-transparent traces from posterior draws → uncertainty spread;
# posterior-mean prediction as solid line → point summary.
N_PP_DRAWS = 2 * N_DEVICES  # must be a multiple of N_DEVICES (n_vmap = N_PP_DRAWS // N_DEVICES)

pp_traces = {}
pp_means  = {}

for key, samples in zip(SCENARIO_KEYS, all_samples):
    n_total  = len(samples["amplitude"])
    draw_idx = jnp.linspace(0, n_total - 1, N_PP_DRAWS).astype(int)

    net_pp = build_network(TRUE_AMPLITUDE, TRUE_EXCITABILITY)
    sf_pp, cfg_pp = prepare(net_pp, Heun(), t0=0.0, t1=T1, dt=DT)

    cfg_pp.external.stimulus.amplitude = DataAxis(samples["amplitude"][draw_idx])
    cfg_pp.dynamics.I                  = DataAxis(samples["excitability"][draw_idx])
    space_pp = Space(cfg_pp, mode="zip")

    par_pp = ParallelExecution(
        model=lambda cfg: sf_pp(cfg).ys[:, 0, 0],
        space=space_pp, n_pmap=N_DEVICES, n_vmap=N_PP_DRAWS // N_DEVICES,
    )
    results = par_pp.run()
    pp_traces[key] = jnp.array([results[i] for i in range(N_PP_DRAWS)])
    pp_means[key]  = (float(jnp.mean(samples["amplitude"])),
                      float(jnp.mean(samples["excitability"])))

print("Posterior predictive simulations complete.")
Figure 6: Posterior-predictive simulations. Faint coloured lines: forward simulations from individual posterior draws. Solid coloured line: simulation at the posterior mean. Black dots: the original observations. All three scenarios fit the data equally well. They disagree only on why.
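
The ensemble can also be reduced to a pointwise credible band instead of overlaid traces. The sketch below computes a 5th to 95th percentile band per time step from a set of simulated traces. The toy forward model and the synthetic "posterior draws" are hypothetical; in the notebook the traces would come from `pp_traces`.

```python
import math
import random


def toy_forward(amplitude, excitability, n_steps=60, tau=10.0, gain=2.0):
    """Hypothetical stand-in for the simulator."""
    drive = amplitude + gain * excitability
    return [drive * math.exp(-i / tau) for i in range(n_steps)]


def percentile(sorted_vals, q):
    """Linearly interpolated percentile of pre-sorted values, q in [0, 100]."""
    pos = (len(sorted_vals) - 1) * q / 100.0
    lo, hi = math.floor(pos), math.ceil(pos)
    frac = pos - lo
    return sorted_vals[lo] * (1.0 - frac) + sorted_vals[hi] * frac


# Synthetic draws scattered along the toy ridge amp + 2 * exc = 0.6.
rng = random.Random(0)
draws = []
for _ in range(200):
    exc = rng.gauss(0.1, 0.05)
    amp = 0.6 - 2.0 * exc + rng.gauss(0.0, 0.02)
    draws.append((amp, exc))

traces = [toy_forward(amp, exc) for amp, exc in draws]

# One (lower, upper) pair per time step.
band = []
for step in range(len(traces[0])):
    column = sorted(trace[step] for trace in traces)
    band.append((percentile(column, 5), percentile(column, 95)))

print("band at first step:", band[0])
print("band at last step: ", band[-1])
```

Although the draws vary widely in amplitude and excitability individually, the band stays narrow because the variation runs along the ridge. That is the same message as the overlaid traces, in a form that is easier to compare across scenarios.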

Comparison: Multi-Start Optimization

Skip Bayesian inference and run gradient descent instead. Sample starting points from each scenario’s prior, run Adam for 1000 steps from each one. Every run collapses to a single point on the ridge. The optimizer cannot represent the residual uncertainty, and the starting distribution (the prior) decides which segment of the ridge each run lands in.

# Gradient descent from prior-sampled starts reveals:
# (a) each run converges to a single point on the ridge — no uncertainty
# (b) the starting distribution (prior) determines which ridge segment is found
# (c) there is no principled uncertainty quantification
N_SAMPLES     = 12 * N_DEVICES
N_OPTIM_STEPS = 1000

# Cache the multi-start optimisation — 12*N_DEVICES Adam runs per scenario × 3
# scenarios is expensive. Set redo=True to rerun after changing learning rate,
# step count, or priors.
@cache("multistart_optim", redo=False)
def run_multistart_optim():
    key_opt     = jax.random.key(99)
    opt_results = {}

    for scenario in SCENARIO_KEYS:
        key_opt, k_amp, k_exc = jax.random.split(key_opt, 3)
        amp_starts = PRIOR_AMP_MEAN + PRIOR_AMP_STD[scenario] * jax.random.normal(k_amp, (N_SAMPLES,))
        exc_starts = PRIOR_EXC_MEAN + PRIOR_EXC_STD[scenario] * jax.random.normal(k_exc, (N_SAMPLES,))

        net_multi = build_network(PRIOR_AMP_MEAN, PRIOR_EXC_MEAN)
        sf_multi, cfg_multi = prepare(net_multi, Heun(), t0=0.0, t1=T1, dt=DT)

        cfg_multi.external.stimulus.amplitude = DataAxis(amp_starts)
        cfg_multi.dynamics.I = DataAxis(exc_starts)
        space_opt = Space(cfg_multi, mode="zip")

        optimizer = OptaxOptimizer(make_loss(sf_multi), optax.adam(learning_rate=0.1))

        def run_optim(config):
            config.external.stimulus.amplitude = Parameter(config.external.stimulus.amplitude)
            config.dynamics.I = Parameter(config.dynamics.I)
            p_fit, _ = optimizer.run(config, max_steps=N_OPTIM_STEPS, chunk_size=N_OPTIM_STEPS)
            return jnp.array([
                collect_parameters(p_fit.external.stimulus.amplitude),
                collect_parameters(p_fit.dynamics.I),
            ])

        par_opt = ParallelExecution(
            model=run_optim, space=space_opt,
            n_pmap=N_DEVICES, n_vmap=N_SAMPLES // N_DEVICES,
        )
        results_opt = par_opt.run()

        final_params = jnp.array([results_opt[i] for i in range(N_SAMPLES)])
        opt_results[scenario] = {
            "amplitude":    final_params[:, 0],
            "excitability": final_params[:, 1],
        }
        print(f"Scenario {scenario}: {N_SAMPLES} optimisations complete.")

    return opt_results

opt_results = run_multistart_optim()
Figure 7: Multi-start optimization endpoints. Each coloured dot is the final (amp, I) of one gradient-descent run. Endpoints cluster onto narrow regions of the ridge, but those regions just track where the starting points (sampled from the prior) happened to lie. The clustering is not a posterior. It is a smeared-out projection of the prior, with no notion of uncertainty attached.
Figure 8: Bayesian posteriors vs optimization endpoints, side by side. Solid lines: Bayesian posterior densities. Translucent histograms: optimization endpoints from the same priors. The two are answers to different questions. Optimization tells you where the loss is low given a starting bias. Bayes tells you what to believe about the parameter given the data and a prior.
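
Why endpoints track the starting points can be seen on a toy loss whose minimum is an exact line. The sketch below uses plain gradient descent rather than Adam, and a hypothetical loss `(amp + gain * exc - drive)**2`. Every run reaches zero loss, but where it lands on the line depends entirely on where it started.

```python
def loss_grad(amp, exc, drive=0.6, gain=2.0):
    """Gradient of the toy loss (amp + gain * exc - drive)**2."""
    residual = amp + gain * exc - drive
    return 2.0 * residual, 2.0 * gain * residual


def descend(amp, exc, lr=0.05, steps=500):
    """Plain gradient descent from one starting point."""
    for _ in range(steps):
        g_amp, g_exc = loss_grad(amp, exc)
        amp -= lr * g_amp
        exc -= lr * g_exc
    return amp, exc


starts = [(0.0, 0.0), (0.2, 0.3), (0.9, -0.1)]
endpoints = [descend(amp, exc) for amp, exc in starts]

for (a0, e0), (a1, e1) in zip(starts, endpoints):
    print(f"start=({a0:.2f}, {e0:.2f}) -> end=({a1:.2f}, {e1:.2f})  "
          f"residual={a1 + 2.0 * e1 - 0.6:.1e}")
```

Each endpoint is the point on the line closest to its start, so a cloud of starts drawn from a prior maps to a cloud of endpoints that reflects the prior and nothing else. Adam's per-parameter scaling changes the exact mapping, but not the dependence on the start.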

Summary Table

The per-scenario numbers in one place. For each parameter and each scenario: mean and std, median and IQR, computed once from the Bayesian posterior and once from the multi-start optimization endpoints.

Figure 9: Summary statistics by scenario and method. For each scenario (A/B/C) and each parameter (amplitude, excitability), we report mean ± std and median ± IQR for the Bayesian posterior and for the multi-start optimization endpoints. Compare across rows to see how the same observation produces different inferences under different hypotheses.
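
The code that builds the table is not shown. A minimal sketch of the four summaries follows, using only the standard library. The `summarize` helper is hypothetical, and the quartile convention (inclusive, linear interpolation) and the population standard deviation are assumptions that may differ from what the notebook's figure code uses.

```python
import statistics


def summarize(values):
    """Mean, std, median and interquartile range of a list of samples."""
    q1, median, q3 = statistics.quantiles(values, n=4, method="inclusive")
    return {
        "mean": statistics.fmean(values),
        "std": statistics.pstdev(values),   # population std (divides by N)
        "median": median,
        "iqr": q3 - q1,
    }


# Tiny made-up sample, just to show the output shape.
example = summarize([1.0, 2.0, 3.0, 4.0, 5.0])
print(example)
```

Applying the same function to posterior samples and to optimization endpoints keeps the two columns of the table comparable. The interpretation still differs: only the posterior spread is an uncertainty statement.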

When HMC stops being the right tool

Two parameters is the easy case. The harder question is what happens when the posterior lives in hundreds or thousands of dimensions, as it does for the connectivity-scale optimisation in the main paper.

What scales. The forward model is a JAX function, so one gradient call costs one forward plus one reverse-mode pass, almost independently of how many parameters the gradient is taken with respect to. For moderate parameter counts (roughly 10 to a few hundred) HMC with dense_mass=True works well, and the integration above is the template.

What does not scale. A dense mass matrix is O(d²) to store and learn, the warm-up cost grows with it, and multimodality compounds as the dimension rises. The cost per leapfrog step is the forward-model cost, which is fine here but painful for a long BOLD pipeline. For N = 14,028 parameters, vanilla NUTS is not the answer.
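
The O(d²) storage cost is easy to put a number on. The sketch below counts entries and bytes for a dense versus a diagonal mass matrix at d = 14,028. The bytes-per-entry values are the usual float32 and float64 sizes, and `mass_matrix_bytes` is a hypothetical helper.

```python
def mass_matrix_bytes(n_params, dense=True, bytes_per_entry=4):
    """Storage for a mass matrix: d*d entries if dense, d if diagonal."""
    entries = n_params * n_params if dense else n_params
    return entries * bytes_per_entry


d = 14_028
dense_entries = d * d
dense_f32 = mass_matrix_bytes(d, dense=True, bytes_per_entry=4)
dense_f64 = mass_matrix_bytes(d, dense=True, bytes_per_entry=8)
diag_f32 = mass_matrix_bytes(d, dense=False, bytes_per_entry=4)

print(f"dense entries : {dense_entries:,}")
print(f"dense float32 : {dense_f32 / 1e9:.2f} GB")
print(f"dense float64 : {dense_f64 / 1e9:.2f} GB")
print(f"diag  float32 : {diag_f32 / 1e3:.1f} kB")
```

Storage alone is close to a gigabyte in single precision, and the matrix also has to be estimated from warm-up samples, which needs many more draws than there are dimensions to be well conditioned. A diagonal mass matrix is cheap but cannot represent the ridge correlations that made `dense_mass=True` useful in the two-parameter case.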

What is the answer at that scale. Three alternatives stay inside the same differentiable workflow:

  1. Variational inference: replace the posterior with a parametric family and minimise a KL objective using the same gradients. JAX-friendly; numpyro has SVI built in.
  2. Simulation-based inference: train a density estimator on (parameters, summary statistic) pairs sampled from the prior and the simulator. Discussed below.
  3. Population-based search: pymoo + JAX gives multi-objective Pareto fronts using the same ParallelExecution backbone, without leaving the JAX toolchain. Useful when uncertainty quantification is less important than exploring competing objectives.

Off-ramp: simulation-based inference

SBI sidesteps the likelihood entirely. Sample parameters from a prior, run the simulator on each, summarise the output, and train a neural density estimator on those (theta, x) pairs. The trained network is an amortized posterior estimator: once trained, it can be conditioned on any observed summary without rerunning the simulator. This trades the explicit likelihood for a learned inverse map, which is exactly the trade you want when the forward model is expensive, the likelihood is hard to write down, or the parameter count is large.

The integration with tvboptim is short because the simulator side is already done: ParallelExecution over DataAxis is the batched simulator that SBI needs. The sketch below shows the shape of the wiring (display only, not executed in this notebook):

import torch
from sbi.inference import SNPE
from sbi.utils import BoxUniform

# 1. Prior in torch — sbi requires torch distributions
prior = BoxUniform(low=torch.tensor([0.0, -0.3]),
                   high=torch.tensor([1.0,  0.6]))
theta = prior.sample((N_SIM,))  # (N_SIM, 2)

# 2. Forward sims on the JAX side via ParallelExecution
net = build_network(TRUE_AMPLITUDE, TRUE_EXCITABILITY)
sf, cfg = prepare(net, Heun(), t0=0.0, t1=T1, dt=DT)
cfg.external.stimulus.amplitude = DataAxis(jnp.asarray(theta[:, 0]))
cfg.dynamics.I                  = DataAxis(jnp.asarray(theta[:, 1]))
sims = ParallelExecution(
    model=lambda c: summary_stats(sf(c).ys[obs_idx, 0, 0]),
    space=Space(cfg, mode="zip"),
    n_pmap=N_DEVICES, n_vmap=N_SIM // N_DEVICES,
).run()

# 3. Hand back to sbi: train the density estimator, then build the posterior from it
x = torch.as_tensor(np.asarray(sims))
inference = SNPE(prior=prior)
density_estimator = inference.append_simulations(theta, x).train()
posterior = inference.build_posterior(density_estimator)
samples = posterior.sample((4000,), x=torch.as_tensor(np.asarray(v_obs_summary)))

Trade-offs:

  • What this buys. It scales to high-dim, accepts non-differentiable summaries, needs no closed-form likelihood, and reuses the same parallel forward model. The summary statistic can be any reduction: peak amplitude, FCD entries, BOLD KS distance.
  • What it costs. torch becomes a runtime dependency, the JAX↔︎torch boundary is a numpy hop, and the training step lives outside the inference loop. On multi-GPU JAX setups the combined toolchain is fragile, which is why tvboptim does not vendor it.
  • When to skip SBI entirely. For multi-objective optimisation without Bayesian uncertainty, pymoo on top of ParallelExecution gives Pareto fronts using the JAX toolchain alone. That is often enough when the question is “which parameter regimes satisfy which objectives” rather than “what is the posterior”.
Note: Takeaways
  • The forward model is degenerate: a ridge of (amplitude, excitability) pairs produces near-identical observations.
  • Optimization collapses that ridge to a single arbitrary point. It is silent about residual uncertainty.
  • Bayesian inference returns the full posterior along the ridge. The prior is where mechanistic evidence (a PET scan, a stimulator log) enters in a way the optimizer cannot represent.
  • All three scenarios fit the data equally well. They differ only in which mechanistic explanation they prefer. That is the practical payoff of making priors explicit.
  • The same ridge geometry appears whenever a forward model has internal trade-offs, including the effective-connectivity case where EI-tuning and Loss-based optimisation produce different solutions for the same data. The remedy there is the same as here: state the priors and report a posterior.