Performance Tips for Delay Coupling

Benchmarking History Management in Delay Differential Equations

Introduction

Delay differential equations (DDEs) are computationally demanding because they require maintaining and accessing a history of past states. In brain network models, transmission delays arise from finite axonal conduction speeds, turning the system into a DDE where coupling terms depend on states at previous time points.

The core challenge is history buffer management: at each integration step, we must:

  1. Store the current state in a buffer
  2. Retrieve states from specific past times (per-connection delays)
  3. Update the buffer for the next step
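These three steps can be sketched with plain JAX arrays. This is a toy illustration, not the library's implementation; the buffer shape and the per-node delay vector are made up for the example:

```python
import jax.numpy as jnp

def step(buffer, new_state, delay_steps):
    """One history update for a toy delay system.

    buffer: (history_length, n_nodes), with buffer[-1] the most recent state.
    delay_steps: per-node delays, measured in integration steps.
    """
    # 2. Retrieve delayed states: count back delay_steps rows from the newest
    delayed = buffer[-1 - delay_steps, jnp.arange(buffer.shape[1])]
    # 1. + 3. Shift the window and store the current state for the next step
    buffer = jnp.roll(buffer, shift=-1, axis=0)
    buffer = buffer.at[-1].set(new_state)
    return buffer, delayed

buffer = jnp.zeros((5, 3)).at[-1].set(jnp.array([1.0, 2.0, 3.0]))
delays = jnp.array([0, 1, 4])  # per-node delay in steps
buffer, delayed = step(buffer, jnp.array([4.0, 5.0, 6.0]), delays)
# delayed[0] reads the newest state (delay 0); the others read older entries
```

Note that the retrieval uses advanced indexing with one delayed read per node, which is exactly the per-connection access pattern that makes the buffer layout performance-critical.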

In JAX, how we implement this buffer management significantly impacts performance—and there is no universal best approach. The optimal strategy depends on:

  • Network size (number of nodes)
  • Time step size (determines buffer length)
  • Whether you need gradients (reverse-mode autodiff has different memory access patterns)
  • Hardware: CPU vs GPU have different memory bandwidth and access patterns

Note: Hardware Dependency

The benchmarks in this document were run on CPU. Results can vary significantly on different hardware, especially when switching to GPU. GPUs typically have much higher memory bandwidth, which can change the relative performance of buffer strategies. We recommend re-running benchmarks on your target hardware with the desired model configuration.

This document benchmarks three buffer strategies across different configurations.

Environment Setup and Imports
import time
import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
from tvboptim.experimental.network_dynamics import Network, prepare
from tvboptim.experimental.network_dynamics.dynamics.tvb import Generic2dOscillator
from tvboptim.experimental.network_dynamics.coupling import DelayedLinearCoupling
from tvboptim.experimental.network_dynamics.graph import DenseDelayGraph
from tvboptim.experimental.network_dynamics.noise import AdditiveNoise
from tvboptim.experimental.network_dynamics.solvers import Heun
from tvboptim.utils import set_cache_path, cache

# Set cache path for benchmark results
set_cache_path("./buffer_benchmark")

Buffer Strategies

TVB-Optim implements three strategies for managing the history buffer in delay coupling:

Roll Strategy

buffer = jnp.roll(buffer, shift=-1, axis=0)
buffer = buffer.at[-1].set(new_state)

The roll strategy shifts all buffer entries by one position and writes the new state at the end. This is conceptually simple but involves copying the entire buffer at each step.

Circular Strategy

write_idx = step % buffer_length
buffer = buffer.at[write_idx].set(new_state)
# Read indices computed with modular arithmetic

The circular strategy maintains a write pointer that wraps around. No data movement is required—only the pointer advances. Reading requires modular index computation.
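The read-side arithmetic can be illustrated on its own. The numbers here are illustrative and not tied to the library:

```python
buffer_length = 5

# After writing at step s, the slot holding the current state is:
s = 7
write_idx = s % buffer_length  # step 7 wraps around to slot 2

# The state from d steps in the past sits d slots behind the write
# pointer, again wrapping around:
d = 3
read_idx = (write_idx - d) % buffer_length  # slot 4, written at step s - d = 4
```

The trade-off relative to roll is clear from this sketch: no array data moves, but every read pays a modulo computation.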

Preallocated Strategy

# Before simulation: allocate full trajectory
trajectory = jnp.zeros((n_steps, n_states, n_nodes))

# During simulation: write directly
trajectory = trajectory.at[step].set(new_state)

The preallocated strategy allocates space for the entire trajectory upfront and writes sequentially. This avoids buffer management entirely but requires knowing the simulation length in advance and uses more memory.
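The memory cost is easy to estimate up front. A back-of-the-envelope sketch, assuming float32 states and the largest configuration used in this benchmark:

```python
# Full-trajectory allocation: n_steps x n_states x n_nodes float32 values
n_steps, n_states, n_nodes = 5_000, 2, 200   # largest sweep configuration
bytes_needed = n_steps * n_states * n_nodes * 4  # 4 bytes per float32
print(f"{bytes_needed / 1e6:.1f} MB")  # -> 8.0 MB
```

At these sizes the overhead is modest, but it grows linearly with simulation length, so very long simulations or much larger networks may make preallocation impractical.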

Note

The buffer_strategy parameter in DelayedLinearCoupling allows you to select the strategy:

coupling = DelayedLinearCoupling(
    incoming_states="V",
    G=1.0,
    buffer_strategy="circular"  # or "roll" (default), "preallocated"
)

Benchmark Configuration

We benchmark all three strategies across:

  • Network sizes: 10, 20, 50, 100, 200 nodes
  • Time steps (dt): 1.0, 0.5, 0.2, 0.1, 0.05, 0.02, 0.01 ms

Smaller dt values create larger history buffers (more past states to store for the same physical delay).
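Concretely, with the 50 ms maximum delay used below, the buffer length follows the same `ceil(max_delay / dt) + 1` formula as `run_benchmark`:

```python
import numpy as np

MAX_DELAY = 50  # ms, matching the benchmark configuration
for dt in [1.0, 0.5, 0.2, 0.1, 0.05, 0.02, 0.01]:
    history_length = int(np.ceil(MAX_DELAY / dt)) + 1
    print(f"dt={dt}: buffer holds {history_length} past states")
```

The buffer thus grows from 51 entries at dt=1.0 to roughly 5,000 at dt=0.01, a 100x difference that strongly affects the relative cost of the three strategies.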

# Benchmark configuration
DT_VALUES = [1.0, 0.5, 0.2, 0.1, 0.05, 0.02, 0.01]
NETWORK_SIZES = [10, 20, 50, 100, 200]
FIXED_TIMESTEPS = 5_000  # Number of simulation steps
N_FORWARD_RUNS = 3       # Runs for forward timing
N_GRADIENT_RUNS = 3      # Runs for gradient timing
G = 0.001                # Coupling strength
MAX_DELAY = 50           # Maximum delay in ms
STRATEGIES = ["roll", "circular", "preallocated"]

Helper Functions
def create_network(graph, G, buffer_strategy):
    """Create a network with the specified buffer strategy."""
    return Network(
        dynamics=Generic2dOscillator(
            a=-1.5,
            tau=4.0,
            b=-15.0,
            c=0.0,
            d=0.015,
            e=3.0,
            f=1.0,
            g=0.0,
            I=0.0,
            INITIAL_STATE=(-0.098, -0.098),
            VARIABLES_OF_INTEREST=("V", "W"),
        ),
        coupling={
            "delayed": DelayedLinearCoupling(
                incoming_states="V",
                G=G,
                buffer_strategy=buffer_strategy,
            )
        },
        graph=graph,
    )


def make_loss_fn(model_fn, state):
    """Create a loss function for gradient benchmarking."""
    def loss_fn(G_value):
        state.coupling.delayed.G = G_value
        result = model_fn(state)
        return jnp.mean(result.ys[-50:, 0, :])
    return loss_fn


def run_benchmark(n_regions, dt, n_timesteps, n_forward_runs=3, n_gradient_runs=3):
    """Run benchmark for all strategies at a specific configuration."""
    # Create random graph
    graph = DenseDelayGraph.random(n_regions, max_delay=MAX_DELAY, key=jax.random.key(0))

    # Compute simulation parameters
    simulation_length = n_timesteps * dt
    max_delay = float(graph.delays.max())
    t_offset = max_delay + dt
    history_length = int(np.ceil(max_delay / dt)) + 1

    results = {}

    # Prepare all models
    models = {}
    states = {}
    for strategy in STRATEGIES:
        network = create_network(graph, G, strategy)
        model_fn, state = prepare(
            network,
            Heun(),
            t0=0.0 + t_offset,
            t1=simulation_length + t_offset,
            dt=dt,
        )
        models[strategy] = jax.jit(model_fn)
        states[strategy] = state

    # Warm-up (JIT compilation)
    sim_results = {}
    for strategy in STRATEGIES:
        result = models[strategy](states[strategy])
        jax.block_until_ready(result.ys)
        sim_results[strategy] = result

    # Forward benchmark
    forward_times = {s: [] for s in STRATEGIES}
    for _ in range(n_forward_runs):
        for strategy in STRATEGIES:
            t_start = time.perf_counter()
            result = models[strategy](states[strategy])
            jax.block_until_ready(result.ys)
            t_end = time.perf_counter()
            forward_times[strategy].append(t_end - t_start)
            sim_results[strategy] = result

    # Gradient benchmark
    grad_fns = {}
    for strategy in STRATEGIES:
        loss_fn = make_loss_fn(models[strategy], states[strategy])
        grad_fns[strategy] = jax.jit(jax.value_and_grad(loss_fn))

    # Warm-up gradients
    G_init = jnp.array(G)
    grad_results = {}
    for strategy in STRATEGIES:
        val, grad = grad_fns[strategy](G_init)
        jax.block_until_ready(grad)
        grad_results[strategy] = (val, grad)

    # Gradient timing
    gradient_times = {s: [] for s in STRATEGIES}
    for _ in range(n_gradient_runs):
        for strategy in STRATEGIES:
            t_start = time.perf_counter()
            val, grad = grad_fns[strategy](G_init)
            jax.block_until_ready(grad)
            t_end = time.perf_counter()
            gradient_times[strategy].append(t_end - t_start)
            grad_results[strategy] = (val, grad)

    # Correctness verification (use roll as reference)
    V_ref = np.array(sim_results["roll"].ys[:, 0, :])
    _, grad_ref = grad_results["roll"]

    for strategy in STRATEGIES:
        V = np.array(sim_results[strategy].ys[:, 0, :])
        _, grad = grad_results[strategy]

        corr_V = np.corrcoef(V_ref.flatten(), V.flatten())[0, 1]
        rmse_V = np.sqrt(np.mean((V_ref - V) ** 2))
        max_err_V = np.max(np.abs(V_ref - V))
        grad_diff = abs(float(grad_ref) - float(grad))

        results[strategy] = {
            "forward_mean": np.mean(forward_times[strategy]),
            "forward_std": np.std(forward_times[strategy]),
            "gradient_mean": np.mean(gradient_times[strategy]),
            "gradient_std": np.std(gradient_times[strategy]),
            "correlation": corr_V,
            "rmse": rmse_V,
            "max_error": max_err_V,
            "grad_diff": grad_diff,
            "is_correct": corr_V > 0.99,
        }

    results["_meta"] = {
        "history_length": history_length,
        "simulation_length": simulation_length,
        "max_delay": max_delay,
    }

    return results

Running the Benchmark

The benchmark sweeps over all combinations of network size and time step, measuring both forward simulation and gradient computation times.

@cache("benchmark_sweep", redo=False)
def run_full_benchmark():
    """Run the complete benchmark sweep (cached)."""
    all_results = {}
    total_configs = len(DT_VALUES) * len(NETWORK_SIZES)
    config_num = 0

    for n_regions in NETWORK_SIZES:
        all_results[n_regions] = {}
        for dt in DT_VALUES:
            config_num += 1
            sim_length = FIXED_TIMESTEPS * dt
            print(f"[{config_num}/{total_configs}] N={n_regions}, dt={dt}, T={sim_length:.1f}ms")

            try:
                result = run_benchmark(n_regions, dt, FIXED_TIMESTEPS, N_FORWARD_RUNS, N_GRADIENT_RUNS)
                all_results[n_regions][dt] = result
            except Exception as e:
                print(f"  ERROR: {e}")
                all_results[n_regions][dt] = None

    return all_results

all_results = run_full_benchmark()
Loading benchmark_sweep from cache, last modified 2025-12-17 15:27:36.942677

Results

Figure 1: Buffer strategy benchmark results. Top row: Best performing strategy for forward pass (left) and gradient computation (center), with speedup curves for the largest network (right). Bottom row: Ratio of gradient to forward computation time for each strategy, showing how gradient overhead varies across configurations.

Correctness Verification

All strategies produce numerically equivalent results (correlation > 0.99 with the roll reference):

all_correct = True
for n_regions in NETWORK_SIZES:
    for dt in DT_VALUES:
        if all_results[n_regions][dt] is not None:
            result = all_results[n_regions][dt]
            for strategy in STRATEGIES:
                if not result[strategy]["is_correct"]:
                    print(f"FAIL: N={n_regions}, dt={dt}, {strategy}: corr={result[strategy]['correlation']:.8f}")
                    all_correct = False

if all_correct:
    print("All configurations PASSED correctness verification.")
All configurations PASSED correctness verification.

Summary Statistics

# Count wins for each strategy (fastest mean time per configuration)
forward_wins = {s: 0 for s in STRATEGIES}
gradient_wins = {s: 0 for s in STRATEGIES}
total_valid = 0

for n_regions in NETWORK_SIZES:
    for dt in DT_VALUES:
        result = all_results[n_regions][dt]
        if result is not None:
            total_valid += 1
            forward_wins[min(STRATEGIES, key=lambda s: result[s]["forward_mean"])] += 1
            gradient_wins[min(STRATEGIES, key=lambda s: result[s]["gradient_mean"])] += 1

print(f"Total configurations tested: {total_valid}")
print("\nForward Pass Wins:")
for s in STRATEGIES:
    pct = 100 * forward_wins[s] / total_valid if total_valid > 0 else 0
    print(f"  {s:<15}: {forward_wins[s]:>3} ({pct:.1f}%)")

print("\nGradient Wins:")
for s in STRATEGIES:
    pct = 100 * gradient_wins[s] / total_valid if total_valid > 0 else 0
    print(f"  {s:<15}: {gradient_wins[s]:>3} ({pct:.1f}%)")
Total configurations tested: 35

Forward Pass Wins:
  roll           :   2 (5.7%)
  circular       :   6 (17.1%)
  preallocated   :  27 (77.1%)

Gradient Wins:
  roll           :  15 (42.9%)
  circular       :  20 (57.1%)
  preallocated   :   0 (0.0%)

Key Findings

  1. No universal winner: The best strategy depends on network size, time step, and whether gradients are needed.

  2. Forward vs Gradient trade-offs: Strategies that perform well for forward simulation may not be optimal for gradient computation due to different memory access patterns in reverse-mode autodiff.

  3. Buffer size matters: Smaller dt creates larger buffers, which changes the relative performance of strategies.

  4. Hardware matters: These results are for CPU. GPU execution can yield different results due to higher memory bandwidth and different memory access characteristics.

  5. Practical guidance:

    • For forward-only simulation: circular or preallocated often win
    • For gradient-based optimization: test your specific configuration, as the optimal choice varies
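When in doubt, the warm-up-then-time pattern used in `run_benchmark` above is easy to reuse for your own model. A minimal helper might look like this (`time_fn` is a hypothetical name, not part of TVB-Optim):

```python
import time
import jax
import jax.numpy as jnp

def time_fn(fn, *args, n_runs=3):
    """Median wall-clock time of a JIT-compiled function, excluding compilation."""
    fn = jax.jit(fn)
    jax.block_until_ready(fn(*args))  # warm-up run triggers compilation
    times = []
    for _ in range(n_runs):
        t0 = time.perf_counter()
        jax.block_until_ready(fn(*args))  # block so async dispatch doesn't skew timing
        times.append(time.perf_counter() - t0)
    return sorted(times)[len(times) // 2]

# Example: time a trivial reduction
t = time_fn(lambda x: jnp.sum(x ** 2), jnp.ones(1000))
print(f"{t * 1e6:.1f} µs per call")
```

Calling this once per `buffer_strategy` value on your target hardware, with your actual network size and dt, gives a direct answer for your configuration.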