---
title: "Extending Coupling - Surface Simulations"
subtitle: "Recreating TVB's Surface Simulations with Subspace Coupling"
format:
  html:
    code-fold: false
    toc: true
    toc-depth: 3
    fig-width: 8
    out-width: "100%"
jupyter: python3
execute:
  cache: true
---
Try this notebook interactively:
[Download .ipynb](https://github.com/virtual-twin/tvboptim/blob/main/docs/advanced/subspace_coupling.ipynb){.btn .btn-primary download="subspace_coupling.ipynb"}
[Download .qmd](subspace_coupling.qmd){.btn .btn-secondary download="subspace_coupling.qmd"}
[Open in Colab](https://colab.research.google.com/github/virtual-twin/tvboptim/blob/main/docs/advanced/subspace_coupling.ipynb){.btn .btn-warning target="_blank"}
# Introduction
Subspace coupling enables **hierarchical network interactions** where fine-grained nodes are organized into coarser regions. This pattern is fundamental in brain network modeling for cortical surface simulations, where thousands of vertices belong to anatomical parcels that communicate via long-range white matter tracts.
The key idea: instead of coupling all nodes directly, we **aggregate node states to regions, apply coupling at the regional level, then distribute results back to nodes**. This three-stage pattern is implemented through the standard coupling API: `prepare()`, `compute()`, and `update_state()`.
## Use Case: Surface Simulations
In The Virtual Brain (TVB), surface simulations model cortical activity at the vertex level (tens of thousands of vertices) while maintaining regional connectivity via structural connectomes (hundreds of regions). This creates a two-level hierarchy:
- **Node level**: Fine-grained cortical vertices with local connectivity
- **Regional level**: Coarse anatomical parcels with long-range delayed connectivity
```{python}
#| output: false
#| echo: false
# Install dependencies if running in Google Colab
try:
    import google.colab
    print("Running in Google Colab - installing dependencies...")
    !pip install -q tvboptim
    print("✓ Dependencies installed!")
except ImportError:
    pass  # Not in Colab, assume dependencies are available
```
```{python}
#| code-fold: true
#| code-summary: "Environment Setup and Imports"
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes, mark_inset
import time
from tvboptim.experimental.network_dynamics import Network, solve
from tvboptim.experimental.network_dynamics.dynamics.tvb import ReducedWongWang
from tvboptim.experimental.network_dynamics.coupling.linear import FastLinearCoupling, DelayedLinearCoupling
from tvboptim.experimental.network_dynamics.coupling.subspace import SubspaceCoupling
from tvboptim.experimental.network_dynamics.graph import SparseGraph, DenseDelayGraph
from tvboptim.experimental.network_dynamics.solvers import Euler
```
# The Subspace Coupling Pattern
## Conceptual Overview
Subspace coupling operates in three stages:
```
1. AGGREGATE: [n_states, n_nodes] → [n_states, n_regions]
2. COUPLE: Apply coupling at regional level
3. DISTRIBUTE: [n_coupling, n_regions] → [n_coupling, n_nodes]
```
**Stage 1 (Aggregate)**: Average node states within each region
$$
s_r = \frac{1}{|R_r|} \sum_{i \in R_r} s_i
$$
**Stage 2 (Couple)**: Apply inner coupling on regional graph
$$
c_r = \sum_{r'} w_{rr'} f(s_{r'}(t - \tau_{rr'}))
$$
**Stage 3 (Distribute)**: Broadcast regional coupling to constituent nodes
$$
c_i = c_{\text{region}(i)}
$$
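The three stages can be traced by hand on a toy problem. The following is a minimal sketch in plain NumPy (not the library's implementation), assuming six nodes in two regions, an identity transfer function $f$, and no delays; all array values are made up for illustration.

```python
import numpy as np

# Toy sizes: 6 nodes grouped into 2 regions, 1 state variable.
region_mapping = np.array([0, 0, 0, 1, 1, 1])                # region id per node
node_state = np.array([[1.0, 2.0, 3.0, 10.0, 20.0, 30.0]])   # [n_states, n_nodes]

# Stage 1 (aggregate): mean over the nodes of each region.
one_hot = np.eye(2)[region_mapping]          # [n_nodes, n_regions]
agg = one_hot / one_hot.sum(axis=0)          # each column sums to 1
regional_state = node_state @ agg            # [n_states, n_regions]

# Stage 2 (couple): linear coupling with f = identity and zero delay.
# weights[r, r'] is the connection strength from region r' to region r.
weights = np.array([[0.0, 0.5],
                    [0.1, 0.0]])
regional_coupling = regional_state @ weights.T   # [n_states, n_regions]

# Stage 3 (distribute): every node receives its region's coupling value.
node_coupling = regional_coupling[:, region_mapping]   # [n_states, n_nodes]

print(regional_state)   # region means: 2 and 20
print(node_coupling)    # 10 for nodes in region 0, 0.2 for nodes in region 1
```

The same indexing and matrix products are available in `jax.numpy`, so the sketch carries over unchanged to JAX arrays.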
## Implementation: The Coupling API
Let's see how this pattern maps to the three-phase coupling API. We'll show conceptual code snippets that capture the essence of `SubspaceCoupling`.
### Phase 1: `prepare()` - Precompute Regional Structures
```python
def prepare(self, network, dt, t0, t1):
    """Precompute aggregation matrices and prepare inner coupling."""
    # Build normalized aggregation matrix [n_nodes, n_regions]
    # Each node contributes 1/|region| to its region's mean
    region_one_hot = jnp.eye(self.n_regions)[self.region_mapping]
    region_counts = jnp.sum(region_one_hot, axis=0)
    region_one_hot_normalized = region_one_hot / region_counts[None, :]
    # Create regional network context for inner coupling
    # This provides a Network-like interface for the regional graph
    regional_context = self._create_regional_context(network, ...)
    # Prepare inner coupling (e.g., DelayedLinearCoupling on regional graph)
    inner_data, inner_state = self.inner_coupling.prepare(
        regional_context, dt, t0, t1
    )
    # Static data: everything compute() and update_state() read but never modify
    coupling_data = Bunch(
        region_mapping=self.region_mapping,
        region_one_hot_normalized=region_one_hot_normalized,
        inner_data=inner_data,
    )
    # Compute initial aggregated state for caching
    initial_regional_state = self.aggregate(network.initial_state, coupling_data)
    # Mutable state: inner coupling state plus the cached regional state
    coupling_state = Bunch(
        inner_state=inner_state,
        cached_regional_state=initial_regional_state,
    )
    return coupling_data, coupling_state
```
**Key insight**: We precompute the aggregation matrix once, avoiding repeated computation. For sparse mappings (many nodes, few regions), we use sparse matrices.
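As a sanity check on the aggregation matrix, the sketch below (plain NumPy, with a hypothetical random mapping) confirms that one matrix product reproduces the explicit per-region mean. It also guards against empty regions: a region with zero nodes would cause a division by zero during normalization.

```python
import numpy as np

rng = np.random.default_rng(0)
n_nodes, n_regions, n_states = 1000, 8, 2
region_mapping = rng.integers(0, n_regions, size=n_nodes)   # hypothetical mapping
node_state = rng.normal(size=(n_states, n_nodes))

one_hot = np.eye(n_regions)[region_mapping]    # [n_nodes, n_regions]
counts = one_hot.sum(axis=0)                   # number of nodes per region
# An empty region would make the normalization divide by zero.
assert np.all(counts > 0), "every region needs at least one node"
agg = one_hot / counts[None, :]                # each column sums to 1

# Aggregation as a single matrix product ...
via_matmul = node_state @ agg                  # [n_states, n_regions]

# ... matches the explicit mean over each region's nodes.
via_loop = np.stack(
    [node_state[:, region_mapping == r].mean(axis=1) for r in range(n_regions)],
    axis=1,
)
print(np.allclose(via_matmul, via_loop))
```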
### Phase 2: `compute()` - Three-Stage Computation
```python
def compute(self, t, state, coupling_data, coupling_state, params, graph):
    """Aggregate → couple → distribute."""
    # Stage 1: Use cached aggregated regional state
    # (computed in previous update_state, avoids redundant aggregation)
    regional_state = coupling_state.cached_regional_state
    # Shape: [n_states, n_regions]
    # Stage 2: Compute coupling at regional level
    regional_coupling = self.inner_coupling.compute(
        t,
        regional_state,
        coupling_data.inner_data,
        coupling_state.inner_state,
        self.inner_coupling.params,
        self.regional_graph  # Regional connectivity
    )
    # Shape: [n_coupling_inputs, n_regions]
    # Stage 3: Distribute regional coupling to nodes (broadcast)
    # distribute() reads coupling_data.region_mapping internally
    node_coupling = self.distribute(regional_coupling, coupling_data)
    # Shape: [n_coupling_inputs, n_nodes]
    return node_coupling
```
**Performance optimization**: We cache the aggregated state from the previous timestep, avoiding redundant aggregation in `compute()`.
### Phase 3: `update_state()` - Update and Cache
```python
def update_state(self, coupling_data, coupling_state, new_state):
    """Update inner coupling state and cache new aggregated state."""
    # Aggregate new node state to regional state
    # aggregate() reads coupling_data.region_one_hot_normalized internally
    regional_state = self.aggregate(new_state, coupling_data)
    # Update inner coupling (e.g., delay buffer with regional states)
    new_inner_state = self.inner_coupling.update_state(
        coupling_data.inner_data,
        coupling_state.inner_state,
        regional_state  # Regional state for delay buffer
    )
    # Cache aggregated state for next compute()
    return Bunch(
        inner_state=new_inner_state,
        cached_regional_state=regional_state
    )
```
**Temporal optimization**: The aggregated state computed here is reused in the next `compute()` call, eliminating duplicate aggregation operations.
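The effect of the cache can be made concrete by counting aggregations. The sketch below uses a hypothetical `ToySubspace` class (not part of tvboptim) with the same call order of prepare, then alternating compute and update: aggregation runs once per step plus once at setup, never inside `compute()`.

```python
import numpy as np


class ToySubspace:
    """Hypothetical stand-in that only tracks how often aggregation runs."""

    def __init__(self, agg, mapping):
        self.agg = agg            # [n_nodes, n_regions]
        self.mapping = mapping    # [n_nodes]
        self.n_aggregations = 0

    def aggregate(self, node_state):
        self.n_aggregations += 1
        return node_state @ self.agg

    def prepare(self, initial_state):
        # Fill the cache once before the first step.
        return self.aggregate(initial_state)

    def compute(self, cached_regional_state):
        # No aggregation here: reuse the cache, then broadcast to nodes.
        return cached_regional_state[:, self.mapping]

    def update_state(self, new_state):
        # Aggregate once per step; the result serves the next compute().
        return self.aggregate(new_state)


mapping = np.array([0, 0, 1, 1])
agg = np.eye(2)[mapping] / 2.0             # two nodes per region
coupling = ToySubspace(agg, mapping)

state = np.array([[0.0, 2.0, 4.0, 6.0]])   # [n_states, n_nodes]
cache = coupling.prepare(state)
n_steps = 5
for _ in range(n_steps):
    c = coupling.compute(cache)
    state = state + 0.1 * (c - state)      # toy relaxation toward the region mean
    cache = coupling.update_state(state)

print(coupling.n_aggregations)             # n_steps + 1
```

Without the cache, `compute()` would aggregate as well, giving roughly twice as many aggregations for a single-stage solver such as Euler.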
## Aggregate and Distribute Methods
These customizable methods implement the projection between node and regional spaces:
```python
def aggregate(self, node_state, coupling_data):
    """Aggregate node states to regional states (default: mean)."""
    # node_state: [n_states, n_nodes]
    # region_one_hot_normalized: [n_nodes, n_regions]
    regional_state = node_state @ coupling_data.region_one_hot_normalized
    # Shape: [n_states, n_regions]
    return regional_state

def distribute(self, regional_coupling, coupling_data):
    """Distribute regional coupling to nodes (default: broadcast)."""
    # regional_coupling: [n_coupling_inputs, n_regions]
    # region_mapping: [n_nodes] with region IDs
    node_coupling = regional_coupling[:, coupling_data.region_mapping]
    # Shape: [n_coupling_inputs, n_nodes]
    return node_coupling
```
**Customization**: Override these methods for alternative strategies (weighted aggregation, scaled distribution, etc.).
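One possible alternative strategy is a weighted mean, for example weighting each vertex by the surface area it represents. The sketch below builds such an aggregation matrix in NumPy; `vertex_area` and its values are hypothetical, and how the matrix would be wired into an overridden `aggregate()` is left open here.

```python
import numpy as np

region_mapping = np.array([0, 0, 1, 1, 1])
vertex_area = np.array([1.0, 3.0, 2.0, 2.0, 4.0])   # hypothetical per-vertex weights

one_hot = np.eye(2)[region_mapping]              # [n_nodes, n_regions]
weighted = one_hot * vertex_area[:, None]        # scale each node's row by its weight
weighted_agg = weighted / weighted.sum(axis=0)   # normalize so columns sum to 1

node_state = np.array([[0.0, 4.0, 1.0, 1.0, 3.0]])   # [n_states, n_nodes]
regional_state = node_state @ weighted_agg

# Region 0: (0*1 + 4*3) / 4 = 3;  region 1: (1*2 + 1*2 + 3*4) / 8 = 2
print(regional_state)
```

Because the columns still sum to one, a spatially uniform node state maps to the same value at the regional level, as it does with the unweighted mean.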
# Practical Example: Mixed Coupling
Let's create a realistically sized surface simulation with both local and regional coupling:
```{python}
# Network dimensions
n_nodes = 16000 # Cortical vertices
n_regions = 76 # Brain regions
t0, t1, dt = 0.0, 1000.0, 1.0
# Regional connectivity: Structural connectome with delays
region_mapping = jax.random.randint(
    jax.random.key(42), (n_nodes,), 0, n_regions
)
regional_graph = DenseDelayGraph.random(
    n_nodes=n_regions,
    sparsity=0.8,  # 80% of connections present
    max_delay=50.0,  # ~150mm at 3 m/s
    key=jax.random.key(0)
)
# Local connectivity: Sparse short-range connections
node_graph = SparseGraph.random(
    n_nodes=n_nodes,
    sparsity=0.000366,  # 0.0366% connectivity (typical density, depending on kernel)
    key=jax.random.key(1)
)
print(f"Local graph: {node_graph.nnz:,} edges (sparse)")
print(f"Regional graph: {n_regions}×{n_regions} (dense with delays)")
```
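The `max_delay` comment above follows from dividing tract length by conduction speed. As a small worked example (with hypothetical tract lengths, and assuming delays are rounded to the nearest integration step, which may differ from the library's discretization):

```python
import numpy as np

tract_length_mm = np.array([30.0, 90.0, 150.0])   # hypothetical tract lengths
speed_m_per_s = 3.0                               # 3 m/s is the same as 3 mm/ms
dt_ms = 1.0                                       # integration step used above

delay_ms = tract_length_mm / speed_m_per_s             # mm / (mm/ms) = ms
delay_steps = np.rint(delay_ms / dt_ms).astype(int)    # buffer depth per connection

print(delay_ms)      # [10. 30. 50.]
print(delay_steps)   # [10 30 50]
```

With `dt = 1.0` ms, a 50 ms maximum delay therefore requires a history buffer reaching 50 steps into the past.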
## Network Construction
Now create a network with two coupling types:
1. **Instantaneous local coupling**: Fast connections between nearby vertices
2. **Delayed regional coupling**: Long-range connections between brain regions
```{python}
# Local instantaneous coupling
coupling_instant = FastLinearCoupling(local_states='S', G=0.2)
# Regional delayed coupling via subspace
coupling_delayed = SubspaceCoupling(
    inner_coupling=DelayedLinearCoupling(incoming_states='S', G=0.05),
    region_mapping=region_mapping,
    regional_graph=regional_graph,
)
# Multi-coupling network
network = Network(
    dynamics=ReducedWongWang(I_o=0.1),
    coupling={
        'instant': coupling_instant,  # Local vertices
        'delayed': coupling_delayed   # Regional subspace
    },
    graph=node_graph
)
print(network)
```
**Key point**: The dynamics model (`ReducedWongWang`) declares two coupling inputs (`instant` and `delayed`). The network provides both through named couplings operating at different spatial scales.
## Simulation
```{python}
start = time.time()
result = solve(network, Euler(), t0=t0, t1=t1, dt=dt)
elapsed = time.time() - start
print(f"\nSimulation: {elapsed:.2f} seconds")
print(f"Result shape: {result.ys.shape}")
print(f"Time points: [{result.ts[0]:.1f}, {result.ts[-1]:.1f}] ms")
```
# Network Inspection
The network printer reveals the hierarchical structure:
```{python}
from tvboptim.experimental.network_dynamics.utils.printer import print_network
print_network(network)
```
**Notice**:
- Coupling 1 (`instant`): Operates on node graph (16,000 nodes, sparse)
- Coupling 2 (`delayed`): Shows regional structure (76 regions, aggregation/distribution)
- Max delay: ~50 ms for long-range regional connections
# History Management for Regional Coupling
When delayed regional coupling is used, the system needs history buffers at the regional level. The `_RegionalNetworkContext` handles this by aggregating node-level history to regional history.
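To see why a regional history buffer is needed, consider how the delayed sum $c_r = \sum_{r'} w_{rr'}\, s_{r'}(t - \tau_{rr'})$ can be evaluated. The sketch below assumes an identity transfer function and a hypothetical buffer layout in which row `k` holds the regional state `k` steps in the past; the library's actual buffer layout may differ.

```python
import numpy as np

n_regions = 3
# weights[r, r']: connection strength from region r' to region r
weights = np.array([[0.0, 1.0, 0.0],
                    [0.5, 0.0, 0.5],
                    [0.0, 2.0, 0.0]])
# delay_steps[r, r']: delay from r' to r, in integration steps
delay_steps = np.array([[0, 1, 2],
                        [1, 0, 1],
                        [2, 1, 0]])
# history[k, r']: state of region r' k steps ago (k = 0 is the current step)
history = np.array([[1.0, 2.0, 3.0],
                    [10.0, 20.0, 30.0],
                    [100.0, 200.0, 300.0]])

# Look up each source region at its own delay:
# delayed[r, r'] = s_{r'}(t - tau_{rr'})
source = np.arange(n_regions)[None, :]
delayed = history[delay_steps, source]

# Weighted sum over source regions gives the coupling per target region.
coupling = (weights * delayed).sum(axis=1)
print(coupling)   # [20. 20. 40.]
```

The buffer only has `n_regions` columns, so its memory cost scales with the number of regions times the maximum delay, not with the number of vertices.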
## Continuing Simulations
If you set custom history on the network (e.g., continuing from a previous simulation), the regional coupling automatically aggregates it:
```{python}
# Run initial simulation
sim1 = solve(network, Euler(), t0=0.0, t1=500.0, dt=1.0)
# Set as history and continue
network.update_history(sim1)
sim2 = solve(network, Euler(), t0=500.0, t1=1000.0, dt=1.0)
print(f"First simulation: {sim1.ys.shape}")
print(f"Second simulation: {sim2.ys.shape}")
```
The regional coupling's `get_history()` method aggregates the node-level history to create appropriate regional delay buffers. This happens transparently via the shared history extraction utility.
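The aggregation of a stored trajectory can be sketched with one batched matrix product. This assumes the node-level history has the same `[time, state, node]` layout as `sim1.ys`; the toy values are made up, and the internals of `get_history()` may differ.

```python
import numpy as np

n_steps, n_states, n_nodes, n_regions = 4, 1, 6, 2
region_mapping = np.array([0, 0, 0, 1, 1, 1])
agg = np.eye(n_regions)[region_mapping]
agg = agg / agg.sum(axis=0)                       # [n_nodes, n_regions]

# Toy node-level history: [n_steps, n_states, n_nodes]
node_history = np.arange(
    n_steps * n_states * n_nodes, dtype=float
).reshape(n_steps, n_states, n_nodes)

# Matrix multiplication broadcasts over the leading time axis,
# so every time slice is averaged per region in one call.
regional_history = node_history @ agg             # [n_steps, n_states, n_regions]

print(regional_history.shape)      # (4, 1, 2)
print(regional_history[:, 0, :])   # region means per time step
```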
## Visualize Continued Simulation
```{python}
#| code-fold: true
#| code-summary: "Visualization: Continuity at Simulation Boundary"
fig, axes = plt.subplots(2, 1, figsize=(8, 4.5), dpi = 200)
# Sample 100 vertices
sample_indices = jnp.linspace(0, n_nodes-1, 100, dtype=int)
# Vertex time series: sim1 and sim2
axes[0].plot(sim1.ts, sim1.ys[:, 0, sample_indices],
alpha=0.3, linewidth=0.5, color='steelblue', label='Sim 1')
axes[0].plot(sim2.ts, sim2.ys[:, 0, sample_indices],
alpha=0.3, linewidth=0.5, color='coral', label='Sim 2')
axes[0].axvline(500, color='black', linestyle='--', linewidth=1, alpha=0.5)
axes[0].set_ylabel('S (synaptic gating)')
axes[0].set_title(f'Cortical Activity: {len(sample_indices)} Vertices (Continued Simulation)')
axes[0].grid(True, alpha=0.3)
# Add zoom inset for vertices
axins0 = inset_axes(axes[0], width="30%", height="50%", loc='lower left',
bbox_to_anchor=(0.55, 0.1, 1, 1), bbox_transform=axes[0].transAxes)
axins0.plot(sim1.ts, sim1.ys[:, 0, sample_indices[:]],
alpha=0.8, linewidth=1, color='steelblue')
axins0.plot(sim2.ts, sim2.ys[:, 0, sample_indices[:]],
alpha=0.8, linewidth=1, color='coral')
axins0.axvline(500, color='black', linestyle='--', linewidth=0.5, alpha=0.5)
axins0.set_xlim(480, 520)
axins0.set_ylim(sim1.ys[480:520, 0, sample_indices[:]].min() - 0.01,
sim1.ys[480:520, 0, sample_indices[:]].max() + 0.01)
axins0.grid(True, alpha=0.3, linewidth=0.5)
axins0.tick_params(labelsize=7)
mark_inset(axes[0], axins0, loc1=2, loc2=4, fc="none", ec="0.5", linestyle='--', linewidth=0.5)
# Mean regional activity
regional_activity_sim1 = []
regional_activity_sim2 = []
for r in range(n_regions):
    mask = region_mapping == r
    regional_activity_sim1.append(jnp.mean(sim1.ys[:, 0, mask], axis=1))
    regional_activity_sim2.append(jnp.mean(sim2.ys[:, 0, mask], axis=1))
regional_activity_sim1 = jnp.array(regional_activity_sim1).T # [time, n_regions]
regional_activity_sim2 = jnp.array(regional_activity_sim2).T
# Plot first 10 regions
axes[1].plot(sim1.ts, regional_activity_sim1[:, :10],
alpha=0.7, linewidth=1.5, color='steelblue')
axes[1].plot(sim2.ts, regional_activity_sim2[:, :10],
alpha=0.7, linewidth=1.5, color='coral')
axes[1].axvline(500, color='black', linestyle='--', linewidth=1, alpha=0.5)
axes[1].set_xlabel('Time [ms]')
axes[1].set_ylabel('Mean S per region')
axes[1].set_title('Regional Activity: 10 Brain Regions (Continued Simulation)')
axes[1].grid(True, alpha=0.3)
# Add zoom inset for regional activity
axins1 = inset_axes(axes[1], width="30%", height="50%", loc='lower left',
bbox_to_anchor=(0.55, 0.1, 1, 1), bbox_transform=axes[1].transAxes)
axins1.plot(sim1.ts, regional_activity_sim1[:, :10],
alpha=0.8, linewidth=1.5, color='steelblue')
axins1.plot(sim2.ts, regional_activity_sim2[:, :10],
alpha=0.8, linewidth=1.5, color='coral')
axins1.axvline(500, color='black', linestyle='--', linewidth=0.5, alpha=0.5)
axins1.set_xlim(480, 520)
axins1.set_ylim(regional_activity_sim1[480:520, 0:10].min() - 0.005,
regional_activity_sim1[480:520, 0:10].max() + 0.005)
axins1.grid(True, alpha=0.3, linewidth=0.5)
axins1.tick_params(labelsize=7)
mark_inset(axes[1], axins1, loc1=2, loc2=4, fc="none", ec="0.5", linestyle='--', linewidth=0.5)
plt.tight_layout()
plt.show()
```
# Key Implementation Insights
## 1. Hierarchical State Management
The coupling maintains two state representations:
- **Node states**: `[n_states, n_nodes]` - full cortical surface
- **Regional states**: `[n_states, n_regions]` - aggregated parcels
Aggregation happens via matrix multiplication with precomputed sparse/dense matrices.
## 2. Performance Optimizations
**Temporal caching**: Avoid redundant aggregation by caching the aggregated state from `update_state()` for use in the next `compute()`.
**Spatial caching**: Precompute aggregation matrices in `prepare()` rather than rebuilding each timestep.
**Sparse operations**: Use BCOO sparse format when `n_nodes >> n_regions` (enabled by default).
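The reason sparsity pays off is that the aggregation matrix has exactly one nonzero entry per node. The sketch below does not use BCOO; it compares the dense product against an index-based sum with `np.bincount` in plain NumPy to illustrate how little of the dense matrix carries information at the sizes used in this tutorial.

```python
import numpy as np

rng = np.random.default_rng(1)
n_nodes, n_regions = 16000, 76
region_mapping = rng.integers(0, n_regions, size=n_nodes)   # hypothetical mapping
node_state = rng.normal(size=n_nodes)                       # one state variable

# Dense route: a full [n_nodes, n_regions] matrix with one nonzero per row.
one_hot = np.eye(n_regions)[region_mapping]
dense_result = node_state @ (one_hot / one_hot.sum(axis=0))

# Index-based route: sum node values per region, then divide by region size.
sums = np.bincount(region_mapping, weights=node_state, minlength=n_regions)
counts = np.bincount(region_mapping, minlength=n_regions)
index_result = sums / counts

dense_entries = one_hot.size                        # stored entries in the dense matrix
nonzero_entries = int(np.count_nonzero(one_hot))    # entries that actually matter

print(np.allclose(dense_result, index_result))
print(dense_entries, nonzero_entries)               # 1216000 vs 16000
```

Only about 1.3% of the dense entries are nonzero (one in 76), which is the structure a sparse format exploits.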
## 3. Nested Coupling Architecture
`SubspaceCoupling` wraps any inner coupling (instantaneous or delayed). The inner coupling sees only the regional graph, unaware it's part of a hierarchical system. This composition enables:
- Local + regional coupling combinations
- Delayed regional coupling with instantaneous local coupling
- Multiple regional couplings with different parameters
## 4. Network Context Pattern
The `_RegionalNetworkContext` provides a minimal Network-like interface for the inner coupling's `prepare()` method. It implements:
- `graph`: Regional graph
- `dynamics`: For state name resolution
- `get_history()`: Aggregates node history to regional history
This duck-typing pattern avoids creating full Network objects while maintaining API compatibility.
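The idea can be illustrated with a small stand-in. In the sketch below, `ToyRegionalContext` and `toy_inner_prepare` are hypothetical names, not tvboptim classes; the point is only that the inner `prepare()` relies on the presence of `graph` and `get_history()`, not on receiving a real `Network`.

```python
from dataclasses import dataclass
from typing import Any

import numpy as np


@dataclass
class ToyRegionalContext:
    """Hypothetical stand-in exposing only what an inner coupling reads."""
    graph: Any                 # regional connectivity
    dynamics: Any              # used for state-name resolution
    node_history: np.ndarray   # [n_steps, n_states, n_nodes]
    agg: np.ndarray            # [n_nodes, n_regions], columns sum to 1

    def get_history(self) -> np.ndarray:
        # Aggregate node history to regional history on demand.
        return self.node_history @ self.agg


def toy_inner_prepare(context) -> dict:
    """Hypothetical inner prepare(): relies on attributes, not on the type."""
    return {
        "n_regions": context.graph.shape[0],
        "buffer": context.get_history(),
    }


region_mapping = np.array([0, 0, 1, 1])
agg = np.eye(2)[region_mapping] / 2.0      # two nodes per region
context = ToyRegionalContext(
    graph=np.ones((2, 2)),
    dynamics=None,
    node_history=np.ones((3, 1, 4)),
    agg=agg,
)
prepared = toy_inner_prepare(context)
print(prepared["n_regions"], prepared["buffer"].shape)   # 2 (3, 1, 2)
```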