---
title: "Performance Tips for Delay Coupling"
subtitle: "Benchmarking History Management in Delay Differential Equations"
format:
html:
code-fold: false
toc: true
echo: false
embed-resources: true
fig-width: 8
out-width: "100%"
jupyter: python3
execute:
cache: true
---
Try this notebook interactively:
[Download .ipynb](https://github.com/virtual-twin/tvboptim/blob/main/docs/advanced/buffer_strategies.ipynb){.btn .btn-primary download="buffer_strategies.ipynb"}
[Download .qmd](buffer_strategies.qmd){.btn .btn-secondary download="buffer_strategies.qmd"}
[Open in Colab](https://colab.research.google.com/github/virtual-twin/tvboptim/blob/main/docs/advanced/buffer_strategies.ipynb){.btn .btn-warning target="_blank"}
## Introduction
Delay differential equations (DDEs) are computationally demanding because they require maintaining and accessing a history of past states. In brain network models, transmission delays arise from finite axonal conduction speeds, turning the system into a DDE where coupling terms depend on states at previous time points.
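Schematically, each node $i$ in such a network evolves according to

$$
\dot{x}_i(t) = f\big(x_i(t)\big) + G \sum_{j} w_{ij}\, x_j\big(t - \tau_{ij}\big),
$$

where $f$ is the local dynamics (the benchmarks below use the `Generic2dOscillator` model), $w_{ij}$ is the connection weight, and $\tau_{ij}$ is the transmission delay from node $j$ to node $i$, i.e. tract length divided by conduction speed. Because the coupling term evaluates $x_j$ at earlier times, the integrator must keep a window of past states.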
```{python}
#| output: false
#| echo: false
# Install dependencies if running in Google Colab
try:
import google.colab
print("Running in Google Colab - installing dependencies...")
!pip install -q tvboptim
print("✓ Dependencies installed!")
except ImportError:
pass # Not in Colab, assume dependencies are available
```
The core challenge is **history buffer management**. At each integration step, we must:
1. Store the current state in a buffer
2. Retrieve states from specific past times (per-connection delays; a sketch follows this list)
3. Update the buffer for the next step
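Step 2 is where the per-connection delays enter: each physical delay has to be converted into an integer offset into the buffer. A minimal sketch with illustrative names and values, not the library's internals (the benchmark helpers below compute `history_length` analogously):

```python
import jax.numpy as jnp

dt = 0.1                                    # integration step (ms)
delays = jnp.array([[0.0, 12.3],
                    [7.5,  0.0]])           # per-connection delays (ms)

# Convert physical delays into integer offsets into the history buffer.
delay_steps = jnp.round(delays / dt).astype(jnp.int32)

# The buffer must span the longest delay plus the current state.
min_buffer_length = int(delay_steps.max()) + 1   # here: 124
```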
In JAX, how we implement this buffer management significantly impacts performance—and there is no universal best approach. The optimal strategy depends on:
- Network size (number of nodes)
- Time step size (determines buffer length)
- Whether you need gradients (reverse-mode autodiff has different memory access patterns)
- **Hardware**: CPUs and GPUs differ in memory bandwidth and memory access patterns
::: {.callout-note}
## Hardware Dependency
The benchmarks in this document were run on **CPU**. Results can vary significantly on different hardware, especially when switching to **GPU**. GPUs typically have much higher memory bandwidth, which can change the relative performance of buffer strategies. We recommend re-running benchmarks on your target hardware with the desired model configuration.
:::
This document benchmarks three buffer strategies across different configurations.
```{python}
#| output: false
#| code-fold: true
#| code-summary: "Environment Setup and Imports"
#| echo: true
import time
import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
from tvboptim.experimental.network_dynamics import Network, prepare
from tvboptim.experimental.network_dynamics.dynamics.tvb import Generic2dOscillator
from tvboptim.experimental.network_dynamics.coupling import DelayedLinearCoupling
from tvboptim.experimental.network_dynamics.graph import DenseDelayGraph
from tvboptim.experimental.network_dynamics.noise import AdditiveNoise
from tvboptim.experimental.network_dynamics.solvers import Heun
from tvboptim.utils import set_cache_path, cache
# Set cache path for benchmark results
set_cache_path("./buffer_benchmark")
```
## Buffer Strategies
TVB-Optim implements three strategies for managing the history buffer in delay coupling:
### Roll Strategy
```python
buffer = jnp.roll(buffer, shift=-1, axis=0)
buffer = buffer.at[-1].set(new_state)
```
The **roll** strategy shifts all buffer entries by one position and writes the new state at the end. This is conceptually simple but involves copying the entire buffer at each step.
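A minimal, self-contained sketch of one roll-based step plus the per-connection delayed read. A single state variable and small illustrative shapes are assumed; names are not the library's internals:

```python
import jax.numpy as jnp

n_nodes, buffer_length = 3, 8
buffer = jnp.zeros((buffer_length, n_nodes))     # newest state lives at buffer[-1]
delay_steps = jnp.array([[0, 2, 5],
                         [2, 0, 3],
                         [5, 3, 0]])             # integer offsets per connection

# One step: shift everything by one slot (copies the whole buffer), then append.
new_state = jnp.ones(n_nodes)
buffer = jnp.roll(buffer, shift=-1, axis=0)
buffer = buffer.at[-1].set(new_state)

# Delayed read: delayed[i, j] is node j's state, delay_steps[i, j] steps ago.
src = jnp.arange(n_nodes)
delayed = buffer[buffer_length - 1 - delay_steps, src]   # shape (n_nodes, n_nodes)
```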
### Circular Strategy
```python
write_idx = step % buffer_length
buffer = buffer.at[write_idx].set(new_state)
# Read indices computed with modular arithmetic
```
The **circular** strategy maintains a write pointer that wraps around. No data movement is required—only the pointer advances. Reading requires modular index computation.
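The matching read resolves each offset modulo the buffer length, relative to the write pointer. A sketch continuing the shapes from the roll example above:

```python
# Fresh buffer laid out circularly: simulation step s is written at s % buffer_length.
buffer = jnp.zeros((buffer_length, n_nodes))
step = 42
write_idx = step % buffer_length
buffer = buffer.at[write_idx].set(new_state)              # no data movement

# The state from k steps ago sits at (write_idx - k) % buffer_length.
read_idx = (write_idx - delay_steps) % buffer_length      # shape (n_nodes, n_nodes)
delayed = buffer[read_idx, jnp.arange(n_nodes)]
```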
### Preallocated Strategy
```python
# Before simulation: allocate full trajectory
trajectory = jnp.zeros((n_steps, n_states, n_nodes))
# During simulation: write directly
trajectory = trajectory.at[step].set(new_state)
```
The **preallocated** strategy allocates space for the entire trajectory upfront and writes sequentially. This avoids buffer management entirely but requires knowing the simulation length in advance and uses more memory.
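The extra memory is easy to estimate up front. A back-of-the-envelope calculation using this document's largest configuration, assuming float32 states (the actual dtype depends on your setup):

```python
n_steps, n_states, n_nodes = 5_000, 2, 200        # largest configuration below
bytes_per_value = 4                               # float32 assumed

trajectory_mb = n_steps * n_states * n_nodes * bytes_per_value / 1e6
print(f"preallocated trajectory: {trajectory_mb:.1f} MB")     # 8.0 MB

# A roll/circular buffer only spans the longest delay:
history_length = int(50 / 1.0) + 1                # 50 ms max delay at dt = 1.0 ms
buffer_mb = history_length * n_states * n_nodes * bytes_per_value / 1e6
print(f"history buffer:          {buffer_mb:.2f} MB")         # 0.08 MB
```

Note that at the smallest dt in the sweep (0.01 ms) the history buffer itself grows to roughly the size of the full 5,000-step trajectory, which is part of why the relative performance of the strategies shifts across configurations.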
::: {.callout-note}
The `buffer_strategy` parameter in `DelayedLinearCoupling` allows you to select the strategy:
```python
coupling = DelayedLinearCoupling(
incoming_states="V",
G=1.0,
buffer_strategy="circular" # or "roll" (default), "preallocated"
)
```
:::
## Benchmark Configuration
We benchmark all three strategies across:
- **Network sizes**: 10, 20, 50, 100, 200 nodes
- **Time steps (dt)**: 1.0, 0.5, 0.2, 0.1, 0.05, 0.02, 0.01 ms
Smaller dt values create larger history buffers (more past states to store for the same physical delay).
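Concretely, with the 50 ms maximum delay used throughout, the buffer length (computed as in the helper functions below) grows by two orders of magnitude across the dt sweep:

```python
import numpy as np

max_delay = 50.0                                   # ms, matches MAX_DELAY below
for dt in [1.0, 0.1, 0.01]:
    history_length = int(np.ceil(max_delay / dt)) + 1
    print(f"dt = {dt} ms -> buffer holds {history_length} past states")
# 51, 501, and 5001 past states for dt = 1.0, 0.1, and 0.01 ms
```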
```{python}
#| echo: true
#| output: false
# Benchmark configuration
DT_VALUES = [1.0, 0.5, 0.2, 0.1, 0.05, 0.02, 0.01]
NETWORK_SIZES = [10, 20, 50, 100, 200]
FIXED_TIMESTEPS = 5_000 # Number of simulation steps
N_FORWARD_RUNS = 3 # Runs for forward timing
N_GRADIENT_RUNS = 3 # Runs for gradient timing
G = 0.001 # Coupling strength
MAX_DELAY = 50 # Maximum delay in ms
STRATEGIES = ["roll", "circular", "preallocated"]
```
```{python}
#| output: false
#| code-fold: true
#| code-summary: "Helper Functions"
#| echo: true
def create_network(graph, G, buffer_strategy):
"""Create a network with the specified buffer strategy."""
return Network(
dynamics=Generic2dOscillator(
a=-1.5,
tau=4.0,
b=-15.0,
c=0.0,
d=0.015,
e=3.0,
f=1.0,
g=0.0,
I=0.0,
INITIAL_STATE=(-0.098, -0.098),
VARIABLES_OF_INTEREST=("V", "W"),
),
coupling={
"delayed": DelayedLinearCoupling(
incoming_states="V",
G=G,
buffer_strategy=buffer_strategy,
)
},
graph=graph,
)
def make_loss_fn(model_fn, state):
"""Create a loss function for gradient benchmarking."""
def loss_fn(G_value):
state.coupling.delayed.G = G_value
result = model_fn(state)
return jnp.mean(result.ys[-50:, 0, :])
return loss_fn
def run_benchmark(n_regions, dt, n_timesteps, n_forward_runs=3, n_gradient_runs=3):
"""Run benchmark for all strategies at a specific configuration."""
# Create random graph
graph = DenseDelayGraph.random(n_regions, max_delay=MAX_DELAY, key=jax.random.key(0))
# Compute simulation parameters
simulation_length = n_timesteps * dt
max_delay = float(graph.delays.max())
t_offset = max_delay + dt
history_length = int(np.ceil(max_delay / dt)) + 1
results = {}
# Prepare all models
models = {}
states = {}
for strategy in STRATEGIES:
network = create_network(graph, G, strategy)
model_fn, state = prepare(
network,
Heun(),
t0=0.0 + t_offset,
t1=simulation_length + t_offset,
dt=dt,
)
models[strategy] = jax.jit(model_fn)
states[strategy] = state
# Warm-up (JIT compilation)
sim_results = {}
for strategy in STRATEGIES:
result = models[strategy](states[strategy])
jax.block_until_ready(result.ys)
sim_results[strategy] = result
# Forward benchmark
forward_times = {s: [] for s in STRATEGIES}
for _ in range(n_forward_runs):
for strategy in STRATEGIES:
t_start = time.perf_counter()
result = models[strategy](states[strategy])
jax.block_until_ready(result.ys)
t_end = time.perf_counter()
forward_times[strategy].append(t_end - t_start)
sim_results[strategy] = result
# Gradient benchmark
grad_fns = {}
for strategy in STRATEGIES:
loss_fn = make_loss_fn(models[strategy], states[strategy])
grad_fns[strategy] = jax.jit(jax.value_and_grad(loss_fn))
# Warm-up gradients
G_init = jnp.array(G)
grad_results = {}
for strategy in STRATEGIES:
val, grad = grad_fns[strategy](G_init)
jax.block_until_ready(grad)
grad_results[strategy] = (val, grad)
# Gradient timing
gradient_times = {s: [] for s in STRATEGIES}
for _ in range(n_gradient_runs):
for strategy in STRATEGIES:
t_start = time.perf_counter()
val, grad = grad_fns[strategy](G_init)
jax.block_until_ready(grad)
t_end = time.perf_counter()
gradient_times[strategy].append(t_end - t_start)
grad_results[strategy] = (val, grad)
# Correctness verification (use roll as reference)
V_ref = np.array(sim_results["roll"].ys[:, 0, :])
_, grad_ref = grad_results["roll"]
for strategy in STRATEGIES:
V = np.array(sim_results[strategy].ys[:, 0, :])
_, grad = grad_results[strategy]
corr_V = np.corrcoef(V_ref.flatten(), V.flatten())[0, 1]
rmse_V = np.sqrt(np.mean((V_ref - V) ** 2))
max_err_V = np.max(np.abs(V_ref - V))
grad_diff = abs(float(grad_ref) - float(grad))
results[strategy] = {
"forward_mean": np.mean(forward_times[strategy]),
"forward_std": np.std(forward_times[strategy]),
"gradient_mean": np.mean(gradient_times[strategy]),
"gradient_std": np.std(gradient_times[strategy]),
"correlation": corr_V,
"rmse": rmse_V,
"max_error": max_err_V,
"grad_diff": grad_diff,
"is_correct": corr_V > 0.99,
}
results["_meta"] = {
"history_length": history_length,
"simulation_length": simulation_length,
"max_delay": max_delay,
}
return results
```
## Running the Benchmark
The benchmark sweeps over all combinations of network size and time step, measuring both forward simulation and gradient computation times.
```{python}
#| echo: true
#| output: true
@cache("benchmark_sweep", redo=False)
def run_full_benchmark():
"""Run the complete benchmark sweep (cached)."""
all_results = {}
total_configs = len(DT_VALUES) * len(NETWORK_SIZES)
config_num = 0
for n_regions in NETWORK_SIZES:
all_results[n_regions] = {}
for dt in DT_VALUES:
config_num += 1
sim_length = FIXED_TIMESTEPS * dt
print(f"[{config_num}/{total_configs}] N={n_regions}, dt={dt}, T={sim_length:.1f}ms")
try:
result = run_benchmark(n_regions, dt, FIXED_TIMESTEPS, N_FORWARD_RUNS, N_GRADIENT_RUNS)
all_results[n_regions][dt] = result
except Exception as e:
print(f" ERROR: {e}")
all_results[n_regions][dt] = None
return all_results
all_results = run_full_benchmark()
```
## Results
```{python}
#| output: false
#| code-fold: true
#| code-summary: "Analyze Results"
# Create matrices for best strategy indices
strategy_list = STRATEGIES
best_forward = np.zeros((len(NETWORK_SIZES), len(DT_VALUES)), dtype=int)
best_gradient = np.zeros((len(NETWORK_SIZES), len(DT_VALUES)), dtype=int)
forward_speedups = {s: np.zeros((len(NETWORK_SIZES), len(DT_VALUES))) for s in STRATEGIES}
gradient_speedups = {s: np.zeros((len(NETWORK_SIZES), len(DT_VALUES))) for s in STRATEGIES}
for i, n_regions in enumerate(NETWORK_SIZES):
for j, dt in enumerate(DT_VALUES):
if all_results[n_regions][dt] is not None:
result = all_results[n_regions][dt]
# Forward: find fastest
fwd_times = [result[s]["forward_mean"] for s in strategy_list]
best_forward[i, j] = np.argmin(fwd_times)
ref_fwd = result["roll"]["forward_mean"]
for k, s in enumerate(strategy_list):
forward_speedups[s][i, j] = ref_fwd / result[s]["forward_mean"]
# Gradient: find fastest
grad_times = [result[s]["gradient_mean"] for s in strategy_list]
best_gradient[i, j] = np.argmin(grad_times)
ref_grad = result["roll"]["gradient_mean"]
for k, s in enumerate(strategy_list):
gradient_speedups[s][i, j] = ref_grad / result[s]["gradient_mean"]
else:
best_forward[i, j] = -1
best_gradient[i, j] = -1
```
```{python}
#| label: fig-benchmark
#| fig-cap: "**Buffer strategy benchmark results.** Top row: Best performing strategy for forward pass (left) and gradient computation (center), with speedup curves for the largest network (right). Bottom row: Ratio of gradient to forward computation time for each strategy, showing how gradient overhead varies across configurations."
#| code-fold: true
#| code-summary: "Visualization Code"
from matplotlib.colors import ListedColormap
from matplotlib.patches import Patch
# Strategy visualization settings
colors = ['gold', 'steelblue', '#000000']
STRATEGY_CONFIG = {
"roll": {"color": colors[0], "marker": "s"},
"circular": {"color": colors[1], "marker": "o"},
"preallocated": {"color": colors[2], "marker": "^"},
}
fig, axes = plt.subplots(2, 3, figsize=(14, 9))
# Create custom colormap for strategies
strategy_colors = [STRATEGY_CONFIG[s]["color"] for s in strategy_list]
strategy_cmap = ListedColormap(strategy_colors)
legend_elements = [Patch(facecolor=STRATEGY_CONFIG[s]["color"], label=s) for s in strategy_list]
# ---- Row 1: Best Strategy Heatmaps ----
# Plot 1: Best strategy for forward pass
ax1 = axes[0, 0]
im1 = ax1.imshow(best_forward, aspect="auto", cmap=strategy_cmap, vmin=0, vmax=2)
ax1.set_xticks(range(len(DT_VALUES)))
ax1.set_xticklabels([str(dt) for dt in DT_VALUES], rotation=45, ha="right")
ax1.set_yticks(range(len(NETWORK_SIZES)))
ax1.set_yticklabels([str(n) for n in NETWORK_SIZES])
ax1.set_xlabel("dt (ms)")
ax1.set_ylabel("Network Size (N)")
ax1.set_title("Best Strategy: Forward Pass")
ax1.legend(handles=legend_elements, loc="upper right", fontsize=7)
# Plot 2: Best strategy for gradient
ax2 = axes[0, 1]
im2 = ax2.imshow(best_gradient, aspect="auto", cmap=strategy_cmap, vmin=0, vmax=2)
ax2.set_xticks(range(len(DT_VALUES)))
ax2.set_xticklabels([str(dt) for dt in DT_VALUES], rotation=45, ha="right")
ax2.set_yticks(range(len(NETWORK_SIZES)))
ax2.set_yticklabels([str(n) for n in NETWORK_SIZES])
ax2.set_xlabel("dt (ms)")
ax2.set_ylabel("Network Size (N)")
ax2.set_title("Best Strategy: Gradient")
ax2.legend(handles=legend_elements, loc="upper right", fontsize=7)
# Plot 3: Summary line plot - speedup vs dt for largest network
ax3 = axes[0, 2]
n_largest = NETWORK_SIZES[-1]
for strategy in strategy_list:
fwd_spd = []
grad_spd = []
dts = []
for j, dt in enumerate(DT_VALUES):
if all_results[n_largest][dt] is not None:
dts.append(dt)
fwd_spd.append(forward_speedups[strategy][-1, j])
grad_spd.append(gradient_speedups[strategy][-1, j])
ax3.plot(dts, fwd_spd, marker=STRATEGY_CONFIG[strategy]["marker"],
color=STRATEGY_CONFIG[strategy]["color"], label=f"{strategy} (fwd)",
lw=2, markersize=6, linestyle="-")
ax3.plot(dts, grad_spd, marker=STRATEGY_CONFIG[strategy]["marker"],
color=STRATEGY_CONFIG[strategy]["color"], label=f"{strategy} (grad)",
lw=1.5, markersize=5, linestyle="--", alpha=0.7)
ax3.axhline(y=1.0, color="red", linestyle=":", alpha=0.5)
ax3.set_xscale("log")
ax3.set_xlabel("dt (ms)")
ax3.set_ylabel("Speedup vs roll")
ax3.set_yscale("log")
ax3.set_title(f"Speedup vs dt (N={n_largest})")
ax3.legend(fontsize=6, ncol=2)
ax3.grid(alpha=0.3)
ax3.invert_xaxis()
# ---- Row 2: Gradient/Forward ratio for each strategy ----
for idx, strategy in enumerate(strategy_list):
ax = axes[1, idx]
# Compute forward and gradient times matrices
forward_times = np.zeros((len(NETWORK_SIZES), len(DT_VALUES)))
gradient_times = np.zeros((len(NETWORK_SIZES), len(DT_VALUES)))
for i, n_regions in enumerate(NETWORK_SIZES):
for j, dt in enumerate(DT_VALUES):
if all_results[n_regions][dt] is not None:
forward_times[i, j] = all_results[n_regions][dt][strategy]["forward_mean"]
gradient_times[i, j] = all_results[n_regions][dt][strategy]["gradient_mean"]
else:
forward_times[i, j] = np.nan
gradient_times[i, j] = np.nan
# Show gradient/forward ratio
ratio = gradient_times / forward_times
im = ax.imshow(ratio, aspect="auto", cmap="cividis_r", vmin=1, vmax=20)
ax.set_xticks(range(len(DT_VALUES)))
ax.set_xticklabels([str(dt) for dt in DT_VALUES], rotation=45, ha="right")
ax.set_yticks(range(len(NETWORK_SIZES)))
ax.set_yticklabels([str(n) for n in NETWORK_SIZES])
ax.set_xlabel("dt (ms)")
ax.set_ylabel("Network Size (N)")
ax.set_title(f"{strategy}: Gradient/Forward Ratio")
plt.colorbar(im, ax=ax, label="Ratio (grad/fwd)")
# Add text annotations
for i in range(len(NETWORK_SIZES)):
for j in range(len(DT_VALUES)):
if not np.isnan(ratio[i, j]):
color = "white" if ratio[i, j] > 10 else "black"
ax.text(j, i, f"{ratio[i, j]:.1f}", ha="center", va="center",
fontsize=6, color=color)
plt.suptitle(
f"Buffer Strategy Benchmark: N x dt Sweep\n"
rf"timesteps={FIXED_TIMESTEPS}, $\tau_\max=${MAX_DELAY}ms , G={G}",
fontsize=12,
fontweight="bold",
)
plt.tight_layout()
plt.show()
```
## Correctness Verification
All strategies should produce numerically consistent trajectories and gradients; here we check that each strategy's output correlates > 0.99 with the roll reference:
```{python}
#| echo: true
all_correct = True
for n_regions in NETWORK_SIZES:
for dt in DT_VALUES:
if all_results[n_regions][dt] is not None:
result = all_results[n_regions][dt]
for strategy in STRATEGIES:
if not result[strategy]["is_correct"]:
print(f"FAIL: N={n_regions}, dt={dt}, {strategy}: corr={result[strategy]['correlation']:.8f}")
all_correct = False
if all_correct:
print("All configurations PASSED correctness verification.")
```
## Summary Statistics
```{python}
#| echo: true
# Count wins for each strategy
forward_wins = {s: 0 for s in strategy_list}
gradient_wins = {s: 0 for s in strategy_list}
total_valid = 0
for i, n_regions in enumerate(NETWORK_SIZES):
for j, dt in enumerate(DT_VALUES):
if all_results[n_regions][dt] is not None:
total_valid += 1
forward_wins[strategy_list[best_forward[i, j]]] += 1
gradient_wins[strategy_list[best_gradient[i, j]]] += 1
print(f"Total configurations tested: {total_valid}")
print(f"\nForward Pass Wins:")
for s in strategy_list:
pct = 100 * forward_wins[s] / total_valid if total_valid > 0 else 0
print(f" {s:<15}: {forward_wins[s]:>3} ({pct:.1f}%)")
print(f"\nGradient Wins:")
for s in strategy_list:
pct = 100 * gradient_wins[s] / total_valid if total_valid > 0 else 0
print(f" {s:<15}: {gradient_wins[s]:>3} ({pct:.1f}%)")
```
## Key Findings
1. **No universal winner**: The best strategy depends on network size, time step, and whether gradients are needed.
2. **Forward vs Gradient trade-offs**: Strategies that perform well for forward simulation may not be optimal for gradient computation due to different memory access patterns in reverse-mode autodiff.
3. **Buffer size matters**: Smaller dt creates larger buffers, which changes the relative performance of strategies.
4. **Hardware matters**: These results are for CPU. GPU execution can yield different results due to higher memory bandwidth and different memory access characteristics.
5. **Practical guidance**:
- For **forward-only** simulation: circular or preallocated often win
- For **gradient-based optimization**: test your specific configuration, as the optimal choice varies (see the sketch below)
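A minimal sketch for comparing strategies on your own configuration, reusing `create_network`, `prepare`, `Heun`, and the constants defined above (the node count, dt, and simulation length here are illustrative, not recommendations):

```python
# Time one forward run per strategy on your own graph and dt.
dt = 0.1
graph = DenseDelayGraph.random(90, max_delay=MAX_DELAY, key=jax.random.key(0))
t_offset = float(graph.delays.max()) + dt             # mirror run_benchmark above

for strategy in STRATEGIES:
    model_fn, state = prepare(create_network(graph, G, strategy), Heun(),
                              t0=t_offset, t1=1_000.0 + t_offset, dt=dt)
    model_fn = jax.jit(model_fn)
    jax.block_until_ready(model_fn(state).ys)          # warm-up (JIT compilation)
    t_start = time.perf_counter()
    jax.block_until_ready(model_fn(state).ys)
    print(f"{strategy:<15}: {time.perf_counter() - t_start:.3f} s")
```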