Extending Coupling - Surface Simulations

Recreating TVB’s Surface Simulations with Subspace Coupling

Introduction

Subspace coupling enables hierarchical network interactions where fine-grained nodes are organized into coarser regions. This pattern is fundamental in brain network modeling for cortical surface simulations, where thousands of vertices belong to anatomical parcels that communicate via long-range white matter tracts.

The key idea: instead of coupling all nodes directly, we aggregate node states to regions, apply coupling at the regional level, then distribute results back to nodes. This three-stage pattern is implemented through the standard coupling API: prepare(), compute(), and update_state().

Use Case: Surface Simulations

In The Virtual Brain (TVB), surface simulations model cortical activity at the vertex level (tens of thousands of vertices) while maintaining regional connectivity via structural connectomes (hundreds of regions). This creates a two-level hierarchy:

  • Node level: Fine-grained cortical vertices with local connectivity
  • Regional level: Coarse anatomical parcels with long-range delayed connectivity

Environment Setup and Imports
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes, mark_inset
import time
from tvboptim.experimental.network_dynamics import Network, solve
from tvboptim.experimental.network_dynamics.dynamics.tvb import ReducedWongWang
from tvboptim.experimental.network_dynamics.coupling.linear import FastLinearCoupling, DelayedLinearCoupling
from tvboptim.experimental.network_dynamics.coupling.subspace import SubspaceCoupling
from tvboptim.experimental.network_dynamics.graph import SparseGraph, DenseDelayGraph
from tvboptim.experimental.network_dynamics.solvers import Euler

The Subspace Coupling Pattern

Conceptual Overview

Subspace coupling operates in three stages:

1. AGGREGATE:  [n_states, n_nodes] → [n_states, n_regions]
2. COUPLE:     Apply coupling at regional level
3. DISTRIBUTE: [n_coupling, n_regions] → [n_coupling, n_nodes]

Stage 1 (Aggregate): Average node states within each region \[ s_r = \frac{1}{|R_r|} \sum_{i \in R_r} s_i \]

Stage 2 (Couple): Apply inner coupling on regional graph \[ c_r = \sum_{r'} w_{rr'} f(s_{r'}(t - \tau_{rr'})) \]

Stage 3 (Distribute): Broadcast regional coupling to constituent nodes \[ c_i = c_{\text{region}(i)} \]
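As a minimal sketch of the three stages, the toy example below maps six nodes onto two regions and uses an instantaneous linear inner coupling (no delays, \(f\) is the identity). All names and values are illustrative and not part of tvboptim.

```python
import jax.numpy as jnp

# Hypothetical toy setup: 6 nodes, 2 regions, 1 state variable
region_mapping = jnp.array([0, 0, 0, 1, 1, 1])
n_regions = 2
node_state = jnp.array([[1.0, 2.0, 3.0, 10.0, 20.0, 30.0]])  # [n_states, n_nodes]

# Stage 1: aggregate (mean over the nodes of each region)
one_hot = jnp.eye(n_regions)[region_mapping]            # [n_nodes, n_regions]
agg = one_hot / one_hot.sum(axis=0)[None, :]
regional_state = node_state @ agg                        # [n_states, n_regions]

# Stage 2: couple at the regional level, c_r = sum_r' w[r, r'] * s_r'
w = jnp.array([[0.0, 0.5],
               [0.25, 0.0]])
regional_coupling = regional_state @ w.T                 # [n_coupling, n_regions]

# Stage 3: distribute (every node receives its region's value)
node_coupling = regional_coupling[:, region_mapping]     # [n_coupling, n_nodes]
```

Here the regional means are 2 and 20, so region 0 receives 0.5 × 20 = 10 and region 1 receives 0.25 × 2 = 0.5; each node then inherits the value of its region.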

Implementation: The Coupling API

Let’s see how this pattern maps to the three-phase coupling API. We’ll show conceptual code snippets that capture the essence of SubspaceCoupling.

Phase 1: prepare() - Precompute Regional Structures

def prepare(self, network, dt, t0, t1):
    """Precompute aggregation matrices and prepare inner coupling."""

    # Build normalized aggregation matrix [n_nodes, n_regions]
    # Each node contributes 1/|region| to its region's mean
    region_one_hot = jnp.eye(self.n_regions)[self.region_mapping]
    region_counts = jnp.sum(region_one_hot, axis=0)
    region_one_hot_normalized = region_one_hot / region_counts[None, :]

    # Create regional network context for inner coupling
    # This provides a Network-like interface for the regional graph
    regional_context = self._create_regional_context(network, ...)

    # Prepare inner coupling (e.g., DelayedLinearCoupling on regional graph)
    inner_data, inner_state = self.inner_coupling.prepare(
        regional_context, dt, t0, t1
    )

    # Compute initial aggregated state for caching
    initial_regional_state = self.aggregate(
        network.initial_state, region_one_hot_normalized
    )

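    # Bundle static data and mutable state
    # (field names match their use in compute() and update_state())
    coupling_data = Bunch(
        region_one_hot_normalized=region_one_hot_normalized,
        region_mapping=self.region_mapping,
        inner_data=inner_data,
    )
    coupling_state = Bunch(
        inner_state=inner_state,
        cached_regional_state=initial_regional_state,
    )

    return coupling_data, coupling_state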

Key insight: We precompute the aggregation matrix once in prepare(), so no timestep has to rebuild it. When nodes greatly outnumber regions, the matrix is mostly zeros (one nonzero entry per node) and is stored in sparse form.
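A quick sanity check on the aggregation matrix, sketched with a made-up mapping: each column should sum to one, and the matrix product should match a per-region mean computed independently with jax.ops.segment_sum. The variable names are illustrative only.

```python
import jax
import jax.numpy as jnp

n_nodes, n_regions = 12, 3
region_mapping = jnp.arange(n_nodes) % n_regions        # every region is populated
node_state = jnp.arange(n_nodes, dtype=jnp.float32)[None, :]  # [1, n_nodes]

# Same construction as in prepare()
one_hot = jnp.eye(n_regions)[region_mapping]
counts = one_hot.sum(axis=0)
agg = one_hot / counts[None, :]

# Regional mean via the precomputed matrix
mean_matmul = node_state @ agg

# Independent check: sum per region, then divide by the region size
sums = jax.ops.segment_sum(node_state[0], region_mapping, num_segments=n_regions)
mean_segment = sums / counts
```

Note that a region without any nodes would have a count of zero, and the division above would then produce NaN in that column. This sketch does not guard against that case.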

Phase 2: compute() - Three-Stage Computation

def compute(self, t, state, coupling_data, coupling_state, params, graph):
    """Aggregate → couple → distribute."""

    # Stage 1: Use cached aggregated regional state
    # (computed in previous update_state, avoids redundant aggregation)
    regional_state = coupling_state.cached_regional_state
    # Shape: [n_states, n_regions]

    # Stage 2: Compute coupling at regional level
    regional_coupling = self.inner_coupling.compute(
        t,
        regional_state,
        coupling_data.inner_data,
        coupling_state.inner_state,
        self.inner_coupling.params,
        self.regional_graph  # Regional connectivity
    )
    # Shape: [n_coupling_inputs, n_regions]

    # Stage 3: Distribute regional coupling to nodes (broadcast)
    node_coupling = self.distribute(
        regional_coupling, coupling_data.region_mapping
    )
    # Shape: [n_coupling_inputs, n_nodes]

    return node_coupling

Performance optimization: We cache the aggregated state from the previous timestep, avoiding redundant aggregation in compute().

Phase 3: update_state() - Update and Cache

def update_state(self, coupling_data, coupling_state, new_state):
    """Update inner coupling state and cache new aggregated state."""

    # Aggregate new node state to regional state
    regional_state = self.aggregate(
        new_state, coupling_data.region_one_hot_normalized
    )

    # Update inner coupling (e.g., delay buffer with regional states)
    new_inner_state = self.inner_coupling.update_state(
        coupling_data.inner_data,
        coupling_state.inner_state,
        regional_state  # Regional state for delay buffer
    )

    # Cache aggregated state for next compute()
    return Bunch(
        inner_state=new_inner_state,
        cached_regional_state=regional_state
    )

Temporal optimization: The aggregated state computed here is reused in the next compute() call, eliminating duplicate aggregation operations.
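The caching pattern can be sketched in isolation as follows. This is a simplified stand-in, not the library code: the functions take only the arguments needed here, and the toy dynamics (each node relaxing toward its region's mean) are hypothetical.

```python
import jax.numpy as jnp

region_mapping = jnp.array([0, 0, 1, 1])
agg = jnp.eye(2)[region_mapping] / 2.0     # [n_nodes, n_regions], two nodes per region

def compute(cached_regional_state):
    # Reads the cache only; no aggregation happens here
    return cached_regional_state[:, region_mapping]

def update_state(new_node_state):
    # Aggregate once per step and keep the result for the next compute()
    return new_node_state @ agg

state = jnp.array([[0.0, 2.0, 4.0, 6.0]])  # [n_states, n_nodes]
cache = update_state(state)                # the role prepare() plays initially

for _ in range(3):
    coupling = compute(cache)
    state = state + 0.1 * (coupling - state)   # hypothetical Euler-like toy step
    cache = update_state(state)
```

Each iteration performs exactly one aggregation, in update_state(); compute() only indexes into the cached regional state.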

Aggregate and Distribute Methods

These customizable methods implement the projection between node and regional spaces:

def aggregate(self, node_state, region_one_hot_normalized):
    """Aggregate node states to regional states (default: mean)."""
    # node_state: [n_states, n_nodes]
    # region_one_hot_normalized: [n_nodes, n_regions]

    regional_state = node_state @ region_one_hot_normalized
    # Shape: [n_states, n_regions]

    return regional_state

def distribute(self, regional_coupling, region_mapping):
    """Distribute regional coupling to nodes (default: broadcast)."""
    # regional_coupling: [n_coupling_inputs, n_regions]
    # region_mapping: [n_nodes] with region IDs

    node_coupling = regional_coupling[:, region_mapping]
    # Shape: [n_coupling_inputs, n_nodes]

    return node_coupling

Customization: Override these methods for alternative strategies (weighted aggregation, scaled distribution, etc.).
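As one possible alternative strategy, the sketch below builds a weighted aggregation matrix, assuming hypothetical per-vertex weights such as vertex areas. It is written as a standalone function rather than as an override, since the exact method signature to override should be checked against the installed tvboptim version.

```python
import jax.numpy as jnp

def weighted_aggregation_matrix(region_mapping, weights, n_regions):
    """Return an [n_nodes, n_regions] matrix whose columns are weights normalized per region."""
    one_hot = jnp.eye(n_regions)[region_mapping]
    weighted = one_hot * weights[:, None]
    return weighted / weighted.sum(axis=0)[None, :]

region_mapping = jnp.array([0, 0, 1, 1])
vertex_area = jnp.array([1.0, 3.0, 2.0, 2.0])    # hypothetical weights
node_state = jnp.array([[0.0, 4.0, 1.0, 3.0]])   # [n_states, n_nodes]

agg_w = weighted_aggregation_matrix(region_mapping, vertex_area, n_regions=2)
regional_state = node_state @ agg_w              # weighted mean per region
```

With equal weights this reduces to the default mean. Here region 0 yields (0·1 + 4·3)/4 = 3, whereas the unweighted mean would be 2.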

Practical Example: Mixed Coupling

Let’s create a realistically sized surface simulation with both local and regional coupling:

# Network dimensions
n_nodes = 16000  # Cortical vertices
n_regions = 76   # Brain regions
t0, t1, dt = 0.0, 1000.0, 1.0

# Region mapping: assign each vertex to a random region
# (a stand-in for a real parcellation)
region_mapping = jax.random.randint(
    jax.random.key(42), (n_nodes,), 0, n_regions
)

# Regional connectivity: structural connectome with delays
regional_graph = DenseDelayGraph.random(
    n_nodes=n_regions,
    sparsity=0.8,      # 80% of connections present
    max_delay=50.0,    # ~150mm at 3 m/s
    key=jax.random.key(0)
)

# Local connectivity: Sparse short-range connections
node_graph = SparseGraph.random(
    n_nodes=n_nodes,
    sparsity=0.000366,  # 0.0366% connectivity (typical density, depending on kernel)
    key=jax.random.key(1)
)

print(f"Local graph: {node_graph.nnz:,} edges (sparse)")
print(f"Regional graph: {n_regions}×{n_regions} (dense with delays)")
Local graph: 93,690 edges (sparse)
Regional graph: 76×76 (dense with delays)
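The max_delay comment above follows from delay = tract length / conduction speed. A small sketch of that arithmetic, assuming the 3 m/s speed from the comment, dt = 1 ms, and made-up tract lengths; how tvboptim sizes its delay buffer internally is not shown here.

```python
import jax.numpy as jnp

conduction_speed = 3.0                           # m/s, numerically equal to mm/ms
tract_lengths = jnp.array([30.0, 90.0, 150.0])   # mm, hypothetical values
dt = 1.0                                         # ms

# Delay in ms, then in whole integration steps
delays_ms = tract_lengths / conduction_speed
delay_steps = jnp.round(delays_ms / dt).astype(jnp.int32)
```

A 150 mm tract therefore corresponds to a 50 ms delay, which is 50 integration steps at dt = 1 ms.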

Network Construction

Now create a network with two coupling types:

  1. Instantaneous local coupling: Fast connections between nearby vertices
  2. Delayed regional coupling: Long-range connections between brain regions

# Local instantaneous coupling
coupling_instant = FastLinearCoupling(local_states='S', G=0.2)

# Regional delayed coupling via subspace
coupling_delayed = SubspaceCoupling(
    inner_coupling=DelayedLinearCoupling(incoming_states='S', G=0.05),
    region_mapping=region_mapping,
    regional_graph=regional_graph,
)

# Multi-coupling network
network = Network(
    dynamics=ReducedWongWang(I_o = 0.1),
    coupling={
        'instant': coupling_instant,   # Local vertices
        'delayed': coupling_delayed    # Regional subspace
    },
    graph=node_graph
)

print(network)
Network(
  dynamics=ReducedWongWang
  nodes=16000
  couplings=['instant', 'delayed']
)

Key point: The dynamics model (ReducedWongWang) declares two coupling inputs (instant and delayed). The network provides both through named couplings operating at different spatial scales.

Simulation

start = time.time()
result = solve(network, Euler(), t0=t0, t1=t1, dt=dt)
elapsed = time.time() - start

print(f"\nSimulation: {elapsed:.2f} seconds")
print(f"Result shape: {result.ys.shape}")
print(f"Time points: [{result.ts[0]:.1f}, {result.ts[-1]:.1f}] ms")

Simulation: 1.20 seconds
Result shape: (1000, 1, 16000)
Time points: [0.0, 1000.0] ms

Network Inspection

The network printer reveals the hierarchical structure:

from tvboptim.experimental.network_dynamics.utils.printer import print_network

print_network(network)
 Network Dynamics Network System
==================================================

Dynamics: ReducedWongWang
  States: S
  Initial: S=0.1

Graph: SparseGraph
  Nodes: 16000
  Density: 0.037%

Couplings
--------------------------------------------------
1. instant (FastLinearCoupling)
   Type: instantaneous
   States: local=S
   Form: 0.2 * Σⱼ wᵢⱼ * Sⱼ + 0.0
   post: 0.2 * (...) + 0.0
   params: G=0.2, b=0.0

2. delayed (Subspace(DelayedLinearCoupling))
   Type: delayed
   Regions: 76
   Aggregation: mean
   Distribution: broadcast
   Form: [76 regions] post(Σᵣ wᵢᵣ * Sᵣ(t - τᵢᵣ))
   post: 0.05 * (...) + 0.0
   Max delay: 49.99186325073242 ms


Dynamics Equations
--------------------------------------------------
    def dynamics(self, t: float, state: jnp.ndarray, params: Bunch, coupling: Bunch, external: Bunch, ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        # Unpack parameters
        a, b, d = params.a, params.b, params.d
        gamma, tau_s = params.gamma, params.tau_s
        w, J_N, I_o = params.w, params.J_N, params.I_o

        # Unpack state and coupling
        S = state[0]  # Synaptic gating variable
        c_instant = coupling.instant[0]
        # ↳ instant: 0.2 * Σⱼ wᵢⱼ * Sⱼ + 0.0
        c_delayed = coupling.delayed[0]
        # ↳ delayed: [76 regions] post(Σᵣ wᵢᵣ * Sᵣ(t - τᵢᵣ))

        # Total input to population (both couplings add via J_N)
        x = w * J_N * S + I_o + J_N * c_instant + J_N * c_delayed

        # Transfer function H(x)
        ax_minus_b = a * x - b
        H = ax_minus_b / (1 - jnp.exp(-d * ax_minus_b))

        # Population dynamics
        dS_dt = -(S / tau_s) + (1 - S) * H * gamma

        # Package results
        derivatives = jnp.array([dS_dt])
        auxiliaries = jnp.array([H])

        return derivatives, auxiliaries


Parameters
--------------------------------------------------
  a=0.27, b=0.108, d=154, gamma=0.641, tau_s=100, w=0.6, J_N=0.261, I_o=0.1

Notice:

  • Coupling 1 (instant): Operates on the node graph (16,000 nodes, sparse)
  • Coupling 2 (delayed): Shows the regional structure (76 regions, aggregation/distribution)
  • Max delay: ~50 ms for long-range regional connections

History Management for Regional Coupling

When delayed regional coupling is used, the system needs history buffers at the regional level. The _RegionalNetworkContext handles this by aggregating node-level history to regional history.

Continuing Simulations

If you set custom history on the network (e.g., continuing from a previous simulation), the regional coupling automatically aggregates it:

# Run initial simulation
sim1 = solve(network, Euler(), t0=0.0, t1=500.0, dt=1.0)

# Set as history and continue
network.update_history(sim1)
sim2 = solve(network, Euler(), t0=500.0, t1=1000.0, dt=1.0)

print(f"First simulation:  {sim1.ys.shape}")
print(f"Second simulation: {sim2.ys.shape}")
First simulation:  (500, 1, 16000)
Second simulation: (500, 1, 16000)

The regional coupling’s get_history() method aggregates the node-level history to create appropriate regional delay buffers. This happens transparently via the shared history extraction utility.
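Conceptually, aggregating a history applies the same node-to-region averaging at every stored time step. A minimal sketch, assuming the node-level history is laid out as [n_steps, n_states, n_nodes]; that layout and all names below are assumptions for illustration, not the library's internals.

```python
import jax.numpy as jnp

n_steps, n_states, n_nodes, n_regions = 5, 1, 6, 2
region_mapping = jnp.array([0, 0, 0, 1, 1, 1])
one_hot = jnp.eye(n_regions)[region_mapping]
agg = one_hot / one_hot.sum(axis=0)[None, :]     # [n_nodes, n_regions]

# Made-up node-level history, assumed layout [n_steps, n_states, n_nodes]
node_history = jnp.arange(
    n_steps * n_states * n_nodes, dtype=jnp.float32
).reshape(n_steps, n_states, n_nodes)

# Average over the node axis per region, independently for each step and state
regional_history = jnp.einsum('tsn,nr->tsr', node_history, agg)
```

The result has shape [n_steps, n_states, n_regions], which is the shape a regional delay buffer would need.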

Visualize Continued Simulation

Visualization: Continuity at Simulation Boundary
fig, axes = plt.subplots(2, 1, figsize=(8, 4.5), dpi = 200)

# Sample 100 vertices
sample_indices = jnp.linspace(0, n_nodes-1, 100, dtype=int)

# Vertex time series: sim1 and sim2
axes[0].plot(sim1.ts, sim1.ys[:, 0, sample_indices],
             alpha=0.3, linewidth=0.5, color='steelblue', label='Sim 1')
axes[0].plot(sim2.ts, sim2.ys[:, 0, sample_indices],
             alpha=0.3, linewidth=0.5, color='coral', label='Sim 2')
axes[0].axvline(500, color='black', linestyle='--', linewidth=1, alpha=0.5)
axes[0].set_ylabel('S (synaptic gating)')
axes[0].set_title(f'Cortical Activity: {len(sample_indices)} Vertices (Continued Simulation)')
axes[0].grid(True, alpha=0.3)

# Add zoom inset for vertices
axins0 = inset_axes(axes[0], width="30%", height="50%", loc='lower left',
                    bbox_to_anchor=(0.55, 0.1, 1, 1), bbox_transform=axes[0].transAxes)
axins0.plot(sim1.ts, sim1.ys[:, 0, sample_indices[:]],
            alpha=0.8, linewidth=1, color='steelblue')
axins0.plot(sim2.ts, sim2.ys[:, 0, sample_indices[:]],
            alpha=0.8, linewidth=1, color='coral')
axins0.axvline(500, color='black', linestyle='--', linewidth=0.5, alpha=0.5)
axins0.set_xlim(480, 520)
axins0.set_ylim(sim1.ys[480:520, 0, sample_indices[:]].min() - 0.01,
                sim1.ys[480:520, 0, sample_indices[:]].max() + 0.01)
axins0.grid(True, alpha=0.3, linewidth=0.5)
axins0.tick_params(labelsize=7)
mark_inset(axes[0], axins0, loc1=2, loc2=4, fc="none", ec="0.5", linestyle='--', linewidth=0.5)

# Mean regional activity
regional_activity_sim1 = []
regional_activity_sim2 = []
for r in range(n_regions):
    mask = region_mapping == r
    regional_activity_sim1.append(jnp.mean(sim1.ys[:, 0, mask], axis=1))
    regional_activity_sim2.append(jnp.mean(sim2.ys[:, 0, mask], axis=1))

regional_activity_sim1 = jnp.array(regional_activity_sim1).T  # [time, n_regions]
regional_activity_sim2 = jnp.array(regional_activity_sim2).T

# Plot first 10 regions
axes[1].plot(sim1.ts, regional_activity_sim1[:, :10],
             alpha=0.7, linewidth=1.5, color='steelblue')
axes[1].plot(sim2.ts, regional_activity_sim2[:, :10],
             alpha=0.7, linewidth=1.5, color='coral')
axes[1].axvline(500, color='black', linestyle='--', linewidth=1, alpha=0.5)
axes[1].set_xlabel('Time [ms]')
axes[1].set_ylabel('Mean S per region')
axes[1].set_title('Regional Activity: 10 Brain Regions (Continued Simulation)')
axes[1].grid(True, alpha=0.3)

# Add zoom inset for regional activity
axins1 = inset_axes(axes[1], width="30%", height="50%", loc='lower left',
                    bbox_to_anchor=(0.55, 0.1, 1, 1), bbox_transform=axes[1].transAxes)
axins1.plot(sim1.ts, regional_activity_sim1[:, :10],
            alpha=0.8, linewidth=1.5, color='steelblue')
axins1.plot(sim2.ts, regional_activity_sim2[:, :10],
            alpha=0.8, linewidth=1.5, color='coral')
axins1.axvline(500, color='black', linestyle='--', linewidth=0.5, alpha=0.5)
axins1.set_xlim(480, 520)
axins1.set_ylim(regional_activity_sim1[480:520, 0:10].min() - 0.005,
                regional_activity_sim1[480:520, 0:10].max() + 0.005)
axins1.grid(True, alpha=0.3, linewidth=0.5)
axins1.tick_params(labelsize=7)
mark_inset(axes[1], axins1, loc1=2, loc2=4, fc="none", ec="0.5", linestyle='--', linewidth=0.5)

plt.tight_layout()
plt.show()

Key Implementation Insights

1. Hierarchical State Management

The coupling maintains two state representations:

  • Node states: [n_states, n_nodes] - full cortical surface
  • Regional states: [n_states, n_regions] - aggregated parcels

Aggregation happens via matrix multiplication with precomputed sparse/dense matrices.

2. Performance Optimizations

Temporal caching: Avoid redundant aggregation by caching the aggregated state from update_state() for use in next compute().

Spatial caching: Precompute aggregation matrices in prepare() rather than rebuilding each timestep.

Sparse operations: Use BCOO sparse format when n_nodes >> n_regions (enabled by default).
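To illustrate why a sparse format fits this case, the sketch below converts a dense aggregation matrix to BCOO with jax.experimental.sparse and checks that both forms give the same regional means. It shows the general technique only; how SubspaceCoupling builds and applies its sparse matrix internally may differ.

```python
import jax.numpy as jnp
from jax.experimental import sparse

n_nodes, n_regions = 1000, 10
region_mapping = jnp.arange(n_nodes) % n_regions
one_hot = jnp.eye(n_regions)[region_mapping]
agg_dense = one_hot / one_hot.sum(axis=0)[None, :]   # [n_nodes, n_regions]

# Only one entry per row is nonzero, so BCOO stores n_nodes values
# instead of n_nodes * n_regions
agg_sparse = sparse.BCOO.fromdense(agg_dense)

node_state = jnp.linspace(0.0, 1.0, n_nodes)[None, :]   # [1, n_nodes]

regional_dense = node_state @ agg_dense
# Sparse-times-dense product, transposed back to [n_states, n_regions]
regional_sparse = (agg_sparse.T @ node_state.T).T
```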

3. Nested Coupling Architecture

SubspaceCoupling wraps any inner coupling (instantaneous or delayed). The inner coupling sees only the regional graph, unaware it’s part of a hierarchical system. This composition enables:

  • Local + regional coupling combinations
  • Delayed regional coupling with instantaneous local coupling
  • Multiple regional couplings with different parameters

4. Network Context Pattern

The _RegionalNetworkContext provides a minimal Network-like interface for the inner coupling’s prepare() method. It implements:

  • graph: Regional graph
  • dynamics: For state name resolution
  • get_history(): Aggregates node history to regional history

This duck-typing pattern avoids creating full Network objects while maintaining API compatibility.
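The duck-typing idea can be sketched with a small stand-in class. RegionalContextSketch and inner_prepare_sketch are hypothetical names, and the get_history() signature and history layout are assumptions; the real _RegionalNetworkContext may differ in both.

```python
from dataclasses import dataclass
from typing import Any

import jax.numpy as jnp


@dataclass
class RegionalContextSketch:
    """Hypothetical stand-in exposing only what an inner prepare() reads."""
    graph: Any          # regional graph
    dynamics: Any       # used for state name resolution
    node_history: Any   # assumed layout [n_steps, n_states, n_nodes]
    agg: Any            # aggregation matrix [n_nodes, n_regions]

    def get_history(self):
        # Aggregate node history to regional history along the node axis
        return self.node_history @ self.agg


def inner_prepare_sketch(context):
    """Hypothetical inner prepare(): touches only the duck-typed attributes."""
    history = context.get_history()
    return {"graph": context.graph, "history_shape": history.shape}


region_mapping = jnp.array([0, 0, 0, 1, 1, 1])
one_hot = jnp.eye(2)[region_mapping]
agg = one_hot / one_hot.sum(axis=0)[None, :]

context = RegionalContextSketch(
    graph="regional_graph_placeholder",
    dynamics=None,
    node_history=jnp.ones((4, 1, 6)),
    agg=agg,
)
prepared = inner_prepare_sketch(context)
```

Because inner_prepare_sketch only accesses graph and get_history(), any object providing those members works, which is the property the context pattern relies on.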