Axes and Spaces: Systematic Parameter Exploration in TVB-Optim

Introduction & Overview

The Axes and Spaces system is a generic parameter exploration framework for JAX-based models. It operates on JAX pytrees — any nested container JAX can traverse: plain dicts, lists, dataclasses, NamedTuples, or the config objects returned by prepare. Replace any leaf with an axis, pass the container to Space, and you get a sequence of fully-resolved parameter combinations ready to run.

The examples in this section use plain dicts to keep the mechanics visible. The execution section applies the same system to a real brain network model config.

Traditional parameter exploration means writing custom loops and managing array indexing by hand. The axis system handles both.

import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from tvboptim.types.spaces import Space, GridAxis, UniformAxis, DataAxis, NumPyroAxis

# Any pytree works — here a plain dict with two axes and one fixed value
preview_state = {
    'coupling_strength': GridAxis(0.0, 2.0, 4),
    'noise_level': UniformAxis(0.01, 0.1, 3),
    'fixed_param': 42.0
}

space = Space(preview_state, mode='product', key=jax.random.key(123))
print(f"Parameter space: {len(space)} combinations")

for i, params in enumerate(space):
    if i >= 5: break
    print(f"Combination {i}: coupling={params['coupling_strength']:.2f}, "
          f"noise={params['noise_level']:.4f}, fixed={params['fixed_param']}")
Parameter space: 12 combinations
Combination 0: coupling=0.00, noise=0.0265, fixed=42.0
Combination 1: coupling=0.00, noise=0.0876, fixed=42.0
Combination 2: coupling=0.00, noise=0.0204, fixed=42.0
Combination 3: coupling=0.67, noise=0.0265, fixed=42.0
Combination 4: coupling=0.67, noise=0.0876, fixed=42.0

Understanding Axes

The AbstractAxis Interface

All axes share the same interface, one method and one property:

  • generate_values(key=None): Returns sample values as a JAX array
  • size: Number of samples this axis generates
# Examine the interface
from tvboptim.types.spaces import GridAxis

grid = GridAxis(0.0, 1.0, 5)
print(f"Axis size: {grid.size}")
print(f"Generated values: {grid.generate_values()}")
print(f"Values shape: {grid.generate_values().shape}")
Axis size: 5
Generated values: [0.   0.25 0.5  0.75 1.  ]
Values shape: (5,)
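The toy class below makes this contract concrete by implementing the same two members. It is a hypothetical sketch built on NumPy's `linspace`, not tvboptim's implementation. The name `ToyGridAxis` is made up for illustration.

```python
# Hypothetical, minimal re-implementation for illustration only;
# this is not tvboptim's GridAxis.
import numpy as np


class ToyGridAxis:
    """Evenly spaced grid over [low, high] with n points."""

    def __init__(self, low, high, n):
        self.low, self.high, self.n = low, high, n

    @property
    def size(self):
        # Number of samples this axis contributes to a Space
        return self.n

    def generate_values(self, key=None):
        # A grid is deterministic, so the key is accepted but ignored
        return np.linspace(self.low, self.high, self.n)


axis = ToyGridAxis(0.0, 1.0, 5)
print(axis.size)               # 5
print(axis.generate_values())  # [0.   0.25 0.5  0.75 1.  ]
```

Accepting a `key` argument even when it is unused lets calling code treat deterministic and random axes uniformly.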

Each axis type uses a different sampling strategy; the interface is always the same.

GridAxis - Deterministic Sampling

GridAxis samples a deterministic grid across a parameter range — identical values every run, no randomness.

# Basic grid sampling
grid_basic = GridAxis(0.0, 2.0, 11)
values_basic = grid_basic.generate_values()

print(f"Grid values: {values_basic}")
print(f"Spacing is uniform: {jnp.allclose(jnp.diff(values_basic), jnp.diff(values_basic)[0])}")

# Multiple grid densities
densities = [5, 10, 20]
for n in densities:
    grid = GridAxis(0.0, 1.0, n)
    values = grid.generate_values()
    print(f"n={n:2d}: {len(values)} values from {values[0]:.3f} to {values[-1]:.3f}")
Grid values: [0.        0.2       0.4       0.6       0.8       1.        1.2
 1.4       1.6       1.8000001 2.       ]
Spacing is uniform: True
n= 5: 5 values from 0.000 to 1.000
n=10: 10 values from 0.000 to 1.000
n=20: 20 values from 0.000 to 1.000
# Visualize grid sampling
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8.1, 4))

# Plot 1: Basic grid values
ax1.scatter(range(len(values_basic)), values_basic, color='blue', s=50)
ax1.set_xlabel('Sample Index')
ax1.set_ylabel('Parameter Value')
ax1.set_title('GridAxis: Linear Spacing')
ax1.grid(True, alpha=0.3)

# Plot 2: Multiple grid densities
colors = ['red', 'green', 'blue']
for i, (n, color) in enumerate(zip(densities, colors)):
    grid = GridAxis(0.0, 1.0, n)
    values = grid.generate_values()
    ax2.scatter(values, [i] * len(values), color=color, s=30,
               label=f'n={n}', alpha=0.7)

ax2.set_xlabel('Parameter Value')
ax2.set_ylabel('Grid Density')
ax2.set_title('GridAxis: Different Densities')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
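The uniform spacing in the first printout follows from both endpoints being included. With n points on [low, high], the step is (high - low) / (n - 1), so GridAxis(0.0, 2.0, 11) steps by 2.0 / 10 = 0.2. Here is a quick NumPy check. It assumes GridAxis lays values out like an endpoint-inclusive `np.linspace`, which is what the printed values suggest.

```python
import numpy as np

low, high, n = 0.0, 2.0, 11
step = (high - low) / (n - 1)          # endpoint-inclusive spacing
values = np.linspace(low, high, n)     # same layout as the printed grid

print(step)                                # 0.2
print(np.allclose(np.diff(values), step))  # True
```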

Shape Broadcasting

GridAxis (and UniformAxis) support a shape parameter. Each sample becomes a constant array of that shape — useful for parameters that are shared uniformly across brain regions:

# Example: Regional coupling strengths (68 brain regions)
n_regions = 68
regional_grid = GridAxis(0.0, 1.0, 5, shape=(n_regions,))
regional_values = regional_grid.generate_values()

print(f"Regional grid shape: {regional_values.shape}")
print(f"Each sample has shape: {regional_values[0].shape}")
print(f"Sample 2 value: {regional_values[2][0]:.3f} (broadcasted to all regions)")
print(f"All regions identical per sample: {jnp.allclose(regional_values[2], regional_values[2][0])}")
Regional grid shape: (5, 68)
Each sample has shape: (68,)
Sample 2 value: 0.500 (broadcasted to all regions)
All regions identical per sample: True
# Visualize the broadcasting
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8.1, 4))

# Show that each sample broadcasts the same value
sample_idx = 2
ax1.plot(regional_values[sample_idx], 'o-', markersize=3)
ax1.set_xlabel('Brain Region')
ax1.set_ylabel('Coupling Strength')
ax1.set_title(f'Sample {sample_idx}: Broadcasted Value = {regional_values[sample_idx][0]:.2f}')

# Show progression across samples
ax2.plot(range(5), regional_values[:, 0], 'o-', linewidth=2, markersize=8)
ax2.set_xlabel('Sample Index')
ax2.set_ylabel('Parameter Value')
ax2.set_title('Parameter Progression Across Samples')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
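Conceptually, a shaped grid is the 1-D grid with a trailing axis broadcast out to the requested shape. The NumPy sketch below illustrates that idea with `np.broadcast_to`. It is not the library's code.

```python
import numpy as np

n_regions, n_samples = 68, 5
grid_1d = np.linspace(0.0, 1.0, n_samples)          # shape (5,)

# Add a trailing axis and broadcast it to (n_samples, n_regions):
# every row holds one grid value repeated across all regions.
# broadcast_to returns a read-only view; call .copy() to write into it.
regional = np.broadcast_to(grid_1d[:, None], (n_samples, n_regions))

print(regional.shape)                   # (5, 68)
print(regional[2, 0], regional[2, -1])  # 0.5 0.5
```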

UniformAxis - Random Sampling

UniformAxis draws samples uniformly at random. Pass the same key and you get identical values.

# Reproducible random sampling
key = jax.random.key(42)
uniform = UniformAxis(0.0, 1.0, 100)

values1 = uniform.generate_values(key)
values2 = uniform.generate_values(key)  # Same key = same values
values3 = uniform.generate_values(jax.random.key(43))  # Different key

print(f"Same key gives identical results: {jnp.allclose(values1, values2)}")
print(f"Different key gives different results: {not jnp.allclose(values1, values3)}")
print(f"Mean value: {jnp.mean(values1):.3f} (should be ~0.5)")
print(f"Value range: [{jnp.min(values1):.3f}, {jnp.max(values1):.3f}]")
Same key gives identical results: True
Different key gives different results: True
Mean value: 0.514 (should be ~0.5)
Value range: [0.010, 0.992]
# Visualize distributions
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8.1, 4))

# Histogram of uniform samples
ax1.hist(values1, bins=20, alpha=0.7, color='green', density=True)
ax1.axhline(y=1.0, color='red', linestyle='--', label='Expected density = 1.0')
ax1.set_xlabel('Parameter Value')
ax1.set_ylabel('Density')
ax1.set_title('UniformAxis Distribution')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Comparison of different sample sizes
sample_sizes = [10, 50, 200]
for size in sample_sizes:
    uniform_temp = UniformAxis(-2.0, 2.0, size)
    samples = uniform_temp.generate_values(jax.random.key(42))
    ax2.scatter(samples, [size] * len(samples), alpha=0.6, s=20,
               label=f'n={size}')

ax2.set_xlabel('Parameter Value')
ax2.set_ylabel('Sample Size')
ax2.set_title('UniformAxis: Different Sample Sizes')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
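The reproducibility guarantee comes from explicit randomness: in JAX, a draw is a pure function of the key, with no hidden global state. The NumPy analog below passes an integer seed to a fresh `Generator` instead. It is a sketch of the same pattern, not how UniformAxis is implemented. The helper `sample_uniform` is hypothetical.

```python
import numpy as np


def sample_uniform(seed, low, high, n):
    # A fresh Generator per call: the seed plays the role of the JAX key
    rng = np.random.default_rng(seed)
    return rng.uniform(low, high, n)   # draws from [low, high)


a = sample_uniform(42, 0.0, 1.0, 100)
b = sample_uniform(42, 0.0, 1.0, 100)
c = sample_uniform(43, 0.0, 1.0, 100)

print(np.array_equal(a, b))  # True
print(np.array_equal(a, c))  # False
```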

DataAxis - Predefined Values

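DataAxis wraps a predefined array of values. Its size is the length of the leading dimension, and each slice along that dimension is one sample, so samples can be scalars, vectors, or whole matrices.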

# Example: Testing specific coupling values from literature
literature_values = jnp.array([0.142, 0.284, 0.426, 0.568, 0.710])
data_axis = DataAxis(literature_values)

print(f"DataAxis size: {data_axis.size}")
print(f"Values: {data_axis.generate_values()}")

# Works with multidimensional data too
connectivity_matrices = jnp.array([
    [[1.0, 0.5], [0.5, 1.0]],  # Weak coupling
    [[1.0, 0.8], [0.8, 1.0]],  # Medium coupling
    [[1.0, 1.2], [1.2, 1.0]]   # Strong coupling
])

matrix_axis = DataAxis(connectivity_matrices)
print(f"\nMatrix axis shape: {matrix_axis.generate_values().shape}")
print("First connectivity matrix:")
print(matrix_axis.generate_values()[0])

# Create some creative data patterns
fib_values = jnp.array([1, 1, 2, 3, 5, 8, 13, 21])
fib_normalized = fib_values / fib_values.max()
print(f"\nFibonacci sequence (normalized): {fib_normalized}")

# Oscillatory pattern
t = jnp.linspace(0, 4*jnp.pi, 20)
oscillatory = 0.5 + 0.3 * jnp.sin(t)
print(f"Oscillatory pattern range: [{jnp.min(oscillatory):.3f}, {jnp.max(oscillatory):.3f}]")
DataAxis size: 5
Values: [0.142 0.284 0.426 0.568 0.71 ]

Matrix axis shape: (3, 2, 2)
First connectivity matrix:
[[1.  0.5]
 [0.5 1. ]]

Fibonacci sequence (normalized): [0.04761905 0.04761905 0.0952381  0.14285715 0.23809524 0.3809524
 0.61904764 1.        ]
Oscillatory pattern range: [0.201, 0.799]
# Visualize different data patterns
fig, axes = plt.subplots(2, 2, figsize=(8.1, 8))

# 1. Literature values
ax = axes[0, 0]
ax.scatter(range(len(literature_values)), literature_values,
          color='red', s=100, marker='s')
ax.set_xlabel('Index')
ax.set_ylabel('Coupling Strength')
ax.set_title('DataAxis: Literature Values')
ax.grid(True, alpha=0.3)

# 2. Fibonacci sequence
fib_axis = DataAxis(fib_normalized)
ax = axes[0, 1]
ax.plot(fib_normalized, 'o-', color='purple', linewidth=2, markersize=8)
ax.set_xlabel('Index')
ax.set_ylabel('Normalized Value')
ax.set_title('DataAxis: Fibonacci Sequence')
ax.grid(True, alpha=0.3)

# 3. Oscillatory pattern
osc_axis = DataAxis(oscillatory)
ax = axes[1, 0]
ax.plot(oscillatory, 'o-', color='orange', linewidth=2, markersize=6)
ax.set_xlabel('Index')
ax.set_ylabel('Parameter Value')
ax.set_title('DataAxis: Oscillatory Pattern')
ax.grid(True, alpha=0.3)

# 4. Connectivity matrix visualization
ax = axes[1, 1]
im = ax.imshow(matrix_axis.generate_values()[1], cmap='viridis', vmin=0, vmax=1.2)
ax.set_title('DataAxis: Connectivity Matrix (Medium)')
plt.colorbar(im, ax=ax, fraction=0.046)

plt.tight_layout()
plt.show()
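The "leading dimension is the sample dimension" rule can be sketched with a hypothetical stand-in class. `ToyDataAxis` is made up for illustration and only mirrors the behaviour printed above (size 3 for a stack of three 2×2 matrices).

```python
import numpy as np


class ToyDataAxis:
    """Hypothetical stand-in: each slice along axis 0 is one sample."""

    def __init__(self, values):
        self.values = np.asarray(values)

    @property
    def size(self):
        # Number of samples = length of the leading dimension
        return self.values.shape[0]

    def generate_values(self, key=None):
        return self.values  # predefined data, no randomness


mats = np.array([[[1.0, 0.5], [0.5, 1.0]],
                 [[1.0, 0.8], [0.8, 1.0]],
                 [[1.0, 1.2], [1.2, 1.0]]])
axis = ToyDataAxis(mats)
print(axis.size)                        # 3
print(axis.generate_values()[0].shape)  # (2, 2)
```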

NumPyroAxis - Distribution Sampling

NumPyroAxis wraps any NumPyro distribution, making it a natural fit for Bayesian workflows and uncertainty quantification.

import numpyro.distributions as dist

# Different distribution types
key = jax.random.key(42)

# Normal distribution - common for many biological parameters
normal_axis = NumPyroAxis(dist.Normal(0.5, 0.15), n=1000)
normal_samples = normal_axis.generate_values(key)

# Beta distribution - great for bounded parameters [0,1]
beta_axis = NumPyroAxis(dist.Beta(2.0, 5.0), n=1000)
beta_samples = beta_axis.generate_values(key)

# LogNormal - for positive-only parameters like time constants
lognormal_axis = NumPyroAxis(dist.LogNormal(0.0, 0.5), n=1000)
lognormal_samples = lognormal_axis.generate_values(key)

print(f"Normal samples: mean={jnp.mean(normal_samples):.3f}, std={jnp.std(normal_samples):.3f}")
print(f"Beta samples: mean={jnp.mean(beta_samples):.3f}, range=[{jnp.min(beta_samples):.3f}, {jnp.max(beta_samples):.3f}]")
print(f"LogNormal samples: mean={jnp.mean(lognormal_samples):.3f}, median={jnp.median(lognormal_samples):.3f}")
Normal samples: mean=0.506, std=0.146
Beta samples: mean=0.285, range=[0.007, 0.933]
LogNormal samples: mean=1.145, median=1.045
# Visualize the distributions
fig, axes = plt.subplots(2, 2, figsize=(8.1, 8))

# Normal distribution
ax = axes[0, 0]
ax.hist(normal_samples, bins=50, density=True, alpha=0.7, color='blue')
ax.set_xlabel('Value')
ax.set_ylabel('Density')
ax.set_title('Normal(0.5, 0.15)')
ax.grid(True, alpha=0.3)

# Beta distribution
ax = axes[0, 1]
ax.hist(beta_samples, bins=50, density=True, alpha=0.7, color='green')
ax.set_xlabel('Value')
ax.set_ylabel('Density')
ax.set_title('Beta(2.0, 5.0)')
ax.grid(True, alpha=0.3)

# LogNormal distribution
ax = axes[1, 0]
ax.hist(lognormal_samples, bins=50, density=True, alpha=0.7, color='red')
ax.set_xlabel('Value')
ax.set_ylabel('Density')
ax.set_title('LogNormal(0.0, 0.5)')
ax.grid(True, alpha=0.3)

# Comparison of all three
ax = axes[1, 1]
ax.hist(normal_samples, bins=50, density=True, alpha=0.5, color='blue', label='Normal')
ax.hist(beta_samples, bins=50, density=True, alpha=0.5, color='green', label='Beta')
ax.hist(lognormal_samples[lognormal_samples < 5], bins=50, density=True, alpha=0.5, color='red', label='LogNormal')
ax.set_xlabel('Value')
ax.set_ylabel('Density')
ax.set_title('Distribution Comparison')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
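The printed sample statistics can be checked against closed-form moments:

- Normal(μ, σ) has mean μ = 0.5.
- Beta(a, b) has mean a / (a + b) = 2/7 ≈ 0.286.
- LogNormal(μ, σ) has median e^μ = 1 and mean e^(μ + σ²/2) = e^0.125 ≈ 1.133.

The short computation below uses only the standard library.

```python
import math

# Closed-form moments for the three distributions used above
normal_mean = 0.5                             # Normal(0.5, 0.15)
beta_mean = 2.0 / (2.0 + 5.0)                 # Beta(a=2, b=5): a / (a + b)
lognormal_median = math.exp(0.0)              # LogNormal(mu=0, sigma=0.5): e^mu
lognormal_mean = math.exp(0.0 + 0.5**2 / 2)   # e^(mu + sigma^2 / 2)

print(f"{beta_mean:.3f}")       # 0.286
print(f"{lognormal_mean:.3f}")  # 1.133
```

The sampled values above (0.285, 1.145, 1.045) sit close to these.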

Independent vs Broadcast Sampling

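NumPyroAxis supports two sampling modes, selected by broadcast_mode. With broadcast_mode=False, every element of sample_shape is drawn independently. With broadcast_mode=True, one value is drawn per axis point and broadcast to sample_shape: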

key = jax.random.key(456)

# Independent mode: each element sampled independently
independent_axis = NumPyroAxis(
    dist.Normal(0.0, 1.0),
    n=3,
    sample_shape=(2, 4),
    broadcast_mode=False
)
independent_samples = independent_axis.generate_values(key)

# Broadcast mode: one sample per axis point, broadcast to shape
broadcast_axis = NumPyroAxis(
    dist.Normal(0.0, 1.0),
    n=3,
    sample_shape=(2, 4),
    broadcast_mode=True
)
broadcast_samples = broadcast_axis.generate_values(key)

print("Independent sampling:")
print(f"Shape: {independent_samples.shape}")
print("Sample 0 (each element different):")
print(independent_samples[0])
print()

print("Broadcast sampling:")
print(f"Shape: {broadcast_samples.shape}")
print("Sample 0 (all elements identical):")
print(broadcast_samples[0])
print(f"All elements identical: {jnp.allclose(broadcast_samples[0], broadcast_samples[0].flatten()[0])}")
Independent sampling:
Shape: (3, 2, 4)
Sample 0 (each element different):
[[ 0.991294   -1.2127206  -0.28042838  0.02664931]
 [-0.07197085  0.3858999  -0.5076543  -1.5987304 ]]

Broadcast sampling:
Shape: (3, 2, 4)
Sample 0 (all elements identical):
[[0.991294 0.991294 0.991294 0.991294]
 [0.991294 0.991294 0.991294 0.991294]]
All elements identical: True
# Visualize the difference
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8.1, 4))

# Independent sampling heatmap
im1 = ax1.imshow(independent_samples[0], cmap='RdBu_r', vmin=-3, vmax=3)
ax1.set_title('Independent Sampling\n(Each element different)')
ax1.set_xlabel('Dimension 2')
ax1.set_ylabel('Dimension 1')
plt.colorbar(im1, ax=ax1, fraction=0.046)

# Broadcast sampling heatmap
im2 = ax2.imshow(broadcast_samples[0], cmap='RdBu_r', vmin=-3, vmax=3)
ax2.set_title('Broadcast Sampling\n(All elements identical)')
ax2.set_xlabel('Dimension 2')
ax2.set_ylabel('Dimension 1')
plt.colorbar(im2, ax=ax2, fraction=0.046)

plt.tight_layout()
plt.show()
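The two modes differ only in how many random numbers are drawn per axis point. The NumPy sketch below shows that shape logic. It is an illustration, not NumPyroAxis internals, and a NumPy seed stands in for the JAX key.

```python
import numpy as np

rng = np.random.default_rng(456)
n, shape = 3, (2, 4)

# Independent: a fresh draw for every element of every sample
independent = rng.normal(0.0, 1.0, size=(n, *shape))

# Broadcast: one scalar per sample, expanded to the sample shape
scalars = rng.normal(0.0, 1.0, size=n)
broadcast = np.broadcast_to(scalars[:, None, None], (n, *shape))

print(independent.shape, broadcast.shape)  # (3, 2, 4) (3, 2, 4)
print(np.all(broadcast[0] == scalars[0]))  # True
```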

Composing with Space

Basic Space Creation

Space takes a state tree — axes where you want variation, plain values where you don’t — and generates all parameter combinations from it.

# Create a realistic brain simulation parameter space
simulation_state = {
    'coupling': {
        'strength': GridAxis(0.1, 0.8, 5),      # Coupling strength
        'delay': UniformAxis(5.0, 25.0, 3),     # Transmission delays (ms)
    },
    'model': {
        'noise_amplitude': DataAxis([0.01, 0.02, 0.05]),  # Noise levels from pilot study
        'time_constant': 10.0,                    # Fixed parameter
    },
    'simulation': {
        'dt': 0.1,                               # Fixed time step
        'duration': 1000,                        # Fixed duration
    }
}

# Create the space
space = Space(simulation_state, mode='product', key=jax.random.key(789))
print(f"Parameter space contains {len(space)} combinations")
print(f"Combination count: 5 × 3 × 3 = {5*3*3} (Grid × Uniform × Data)")

# Look at a few parameter combinations
print("\nFirst 3 parameter combinations:")
print("=" * 50)

for i, params in enumerate(space):
    if i >= 3: break
    print(f"\nCombination {i}:")
    print(f"  Coupling strength: {params['coupling']['strength']:.3f}")
    print(f"  Coupling delay: {params['coupling']['delay']:.2f} ms")
    print(f"  Noise amplitude: {params['model']['noise_amplitude']:.3f}")
    print(f"  Time constant: {params['model']['time_constant']} (fixed)")
    print(f"  Simulation dt: {params['simulation']['dt']} (fixed)")
Parameter space contains 45 combinations
Combination count: 5 × 3 × 3 = 45 (Grid × Uniform × Data)

First 3 parameter combinations:
==================================================

Combination 0:
  Coupling strength: 0.100
  Coupling delay: 16.82 ms
  Noise amplitude: 0.010
  Time constant: 10.0 (fixed)
  Simulation dt: 0.1 (fixed)

Combination 1:
  Coupling strength: 0.100
  Coupling delay: 16.82 ms
  Noise amplitude: 0.020
  Time constant: 10.0 (fixed)
  Simulation dt: 0.1 (fixed)

Combination 2:
  Coupling strength: 0.100
  Coupling delay: 16.82 ms
  Noise amplitude: 0.050
  Time constant: 10.0 (fixed)
  Simulation dt: 0.1 (fixed)
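The sketch below illustrates what "resolving" a state tree means: find the varying leaves, iterate over their combinations, and rebuild one concrete tree per combination. It is a minimal stdlib sketch that assumes nested dicts only, whereas Space works on arbitrary JAX pytrees. The names `Choices`, `find_axes`, `set_path`, and `expand` are hypothetical.

```python
import itertools


class Choices:
    """Hypothetical axis marker: a leaf that should vary over `values`."""

    def __init__(self, values):
        self.values = list(values)


def find_axes(tree, path=()):
    """Yield (path, axis) for every Choices leaf in a nested dict."""
    for key, value in tree.items():
        if isinstance(value, dict):
            yield from find_axes(value, path + (key,))
        elif isinstance(value, Choices):
            yield path + (key,), value


def set_path(tree, path, value):
    """Return a copy of `tree` with the leaf at `path` replaced (input untouched)."""
    head, *rest = path
    new = dict(tree)
    new[head] = set_path(tree[head], rest, value) if rest else value
    return new


def expand(tree):
    """Yield one fully resolved tree per Cartesian-product combination."""
    axes = list(find_axes(tree))
    for combo in itertools.product(*(axis.values for _, axis in axes)):
        point = tree
        for (path, _), value in zip(axes, combo):
            point = set_path(point, path, value)
        yield point


state = {
    'coupling': {'strength': Choices([0.1, 0.45, 0.8]), 'delay': Choices([5.0, 25.0])},
    'model': {'time_constant': 10.0},  # fixed leaf, passed through unchanged
}
points = list(expand(state))
print(len(points))            # 6
print(points[0]['coupling'])  # {'strength': 0.1, 'delay': 5.0}
```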

Combination Modes: Product vs Zip

The two combination modes serve different exploration strategies:

Product Mode (Cartesian Product)

Tests every combination of parameter values. Use this for systematic exploration.

# Product mode: systematic exploration
product_state = {
    'param_a': GridAxis(0.0, 1.0, 3),
    'param_b': GridAxis(0.0, 1.0, 2),
}

product_space = Space(product_state, mode='product')
print(f"Product mode: {len(product_space)} combinations")

print("\nAll combinations (product mode):")
for i, params in enumerate(product_space):
    print(f"  {i}: a={params['param_a']:.1f}, b={params['param_b']:.1f}")
Product mode: 6 combinations

All combinations (product mode):
  0: a=0.0, b=0.0
  1: a=0.0, b=1.0
  2: a=0.5, b=0.0
  3: a=0.5, b=1.0
  4: a=1.0, b=0.0
  5: a=1.0, b=1.0
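For this two-axis case, the ordering matches Python's `itertools.product`, where the last axis varies fastest. The total count is also the product of the axis sizes. This is a stdlib comparison, not a claim about how Space orders leaves in larger trees.

```python
import itertools
import math

a_values = [0.0, 0.5, 1.0]   # same values as GridAxis(0.0, 1.0, 3)
b_values = [0.0, 1.0]        # same values as GridAxis(0.0, 1.0, 2)

combos = list(itertools.product(a_values, b_values))
print(len(combos) == math.prod([len(a_values), len(b_values)]))  # True
print(combos[:3])  # [(0.0, 0.0), (0.0, 1.0), (0.5, 0.0)]
```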

Zip Mode (Parallel Sampling)

Pairs the i-th value of each axis, so parameters move together. If the axes have different sizes, the space is truncated to the shortest axis and a warning is printed.

# Zip mode: matched sampling
zip_space = Space(product_state, mode='zip')
print(f"Zip mode: {len(zip_space)} combinations (uses minimum axis size)")

print("\nMatched combinations (zip mode):")
for i, params in enumerate(zip_space):
    print(f"  {i}: a={params['param_a']:.1f}, b={params['param_b']:.1f}")
WARNING: In zip mode, effective axes have different sizes [3, 2]. Using minimum size 2, losing 1 combinations.
Zip mode: 2 combinations (uses minimum axis size)

Matched combinations (zip mode):
  0: a=0.0, b=0.0
  1: a=0.5, b=1.0
# Create a more interesting comparison
comparison_state = {
    'x': GridAxis(0.0, 2.0, 8),
    'y': GridAxis(0.0, 2.0, 5),
}

product_comp = Space(comparison_state, mode='product')
zip_comp = Space(comparison_state, mode='zip')

# Extract coordinates for plotting
product_x = [p['x'] for p in product_comp]
product_y = [p['y'] for p in product_comp]

zip_x = [p['x'] for p in zip_comp]
zip_y = [p['y'] for p in zip_comp]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8.1, 5))

# Product mode visualization
ax1.scatter(product_x, product_y, color='blue', s=50, alpha=0.7)
ax1.set_xlabel('Parameter X')
ax1.set_ylabel('Parameter Y')
ax1.set_title(f'Product Mode\n({len(product_comp)} combinations)')
ax1.grid(True, alpha=0.3)
ax1.set_xlim(-0.2, 2.2)
ax1.set_ylim(-0.2, 2.2)

# Zip mode visualization
ax2.scatter(zip_x, zip_y, color='red', s=100, alpha=0.7)
ax2.plot(zip_x, zip_y, 'r--', alpha=0.5, linewidth=1)
ax2.set_xlabel('Parameter X')
ax2.set_ylabel('Parameter Y')
ax2.set_title(f'Zip Mode\n({len(zip_comp)} combinations)')
ax2.grid(True, alpha=0.3)
ax2.set_xlim(-0.2, 2.2)
ax2.set_ylim(-0.2, 2.2)

plt.tight_layout()
plt.show()
WARNING: In zip mode, effective axes have different sizes [8, 5]. Using minimum size 5, losing 3 combinations.
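Zip mode behaves like Python's built-in `zip`, which also stops at the shortest input. Space additionally warns about the dropped points. The stdlib illustration below uses the same values as the 3 × 2 example above. The `lost` count matches that warning for this two-axis case.

```python
a_values = [0.0, 0.5, 1.0]   # same values as GridAxis(0.0, 1.0, 3)
b_values = [0.0, 1.0]        # same values as GridAxis(0.0, 1.0, 2)

pairs = list(zip(a_values, b_values))  # stops at the shorter input
lost = max(len(a_values), len(b_values)) - len(pairs)

print(pairs)  # [(0.0, 0.0), (0.5, 1.0)]
print(lost)   # 1
```

If silent truncation is a risk in your own helper code, Python 3.10+ offers `zip(..., strict=True)`, which raises `ValueError` instead.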

Grouped Axes: Mixing Combination Modes

In many real-world scenarios, you need both zip and product behavior in the same space. For example, you might have matched experimental settings that must move in lockstep, combined with a parameter grid you want to explore exhaustively.

The group parameter on axes solves this. Axes sharing the same group label are zipped together into a single composite axis, regardless of where they sit in the state tree. Then all groups and ungrouped axes combine using the Space’s mode (typically product).

Basic Example

# Three matched settings that must stay together
setting_values_a = jnp.array([1.0, 2.0, 3.0])
setting_values_b = jnp.array([10.0, 20.0, 30.0])

grouped_state = {
    'setting_a': DataAxis(setting_values_a, group='settings'),
    'setting_b': DataAxis(setting_values_b, group='settings'),
    'param_x': GridAxis(0.0, 1.0, 4),
}

space = Space(grouped_state, mode='product', key=jax.random.key(42))
print(f"Total combinations: {len(space)}")
print(f"  = 3 settings x 4 param_x = {3 * 4}")
print(f"  (NOT 3 x 3 x 4 = {3 * 3 * 4})")

print("\nAll combinations:")
for i, params in enumerate(space):
    print(f"  {i}: setting_a={params['setting_a']:.0f}, "
          f"setting_b={params['setting_b']:.0f}, "
          f"param_x={params['param_x']:.2f}")
Total combinations: 12
  = 3 settings x 4 param_x = 12
  (NOT 3 x 3 x 4 = 36)

All combinations:
  0: setting_a=1, setting_b=10, param_x=0.00
  1: setting_a=1, setting_b=10, param_x=0.33
  2: setting_a=1, setting_b=10, param_x=0.67
  3: setting_a=1, setting_b=10, param_x=1.00
  4: setting_a=2, setting_b=20, param_x=0.00
  5: setting_a=2, setting_b=20, param_x=0.33
  6: setting_a=2, setting_b=20, param_x=0.67
  7: setting_a=2, setting_b=20, param_x=1.00
  8: setting_a=3, setting_b=30, param_x=0.00
  9: setting_a=3, setting_b=30, param_x=0.33
  10: setting_a=3, setting_b=30, param_x=0.67
  11: setting_a=3, setting_b=30, param_x=1.00

Notice that setting_a and setting_b always move together (1 with 10, 2 with 20, 3 with 30) while param_x varies independently across its full grid.

Groups work across subtrees – grouped axes don’t need to be co-located in the state tree. You can also use multiple independent groups (each with its own label), and they each zip internally before being combined via the Space’s mode. The group label can be any hashable value (strings, ints, tuples, etc.).
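The grouping rule is "zip inside each group, then product across groups and ungrouped axes". The stdlib sketch below expresses it over a flat, hypothetical description of the axes, and reproduces the 12 combinations printed above. The ordering of composite axes (first appearance) is chosen to match that printout and is an assumption, not documented Space behaviour.

```python
import itertools

# Hypothetical flat description: name -> (group label or None, values)
axes = {
    'setting_a': ('settings', [1.0, 2.0, 3.0]),
    'setting_b': ('settings', [10.0, 20.0, 30.0]),
    'param_x': (None, [0.0, 1 / 3, 2 / 3, 1.0]),
}

# 1. Collect members per composite axis, in order of first appearance.
#    Ungrouped axes get a private key so each stands alone.
members = {}
for name, (group, values) in axes.items():
    key = group if group is not None else ('__ungrouped__', name)
    members.setdefault(key, []).append((name, values))

# 2. Zip within each composite axis (shortest member wins, as in zip mode)
composites = []
for entries in members.values():
    names = [name for name, _ in entries]
    columns = [values for _, values in entries]
    composites.append([dict(zip(names, row)) for row in zip(*columns)])

# 3. Cartesian product across composite axes, merging the partial dicts
points = []
for combo in itertools.product(*composites):
    point = {}
    for part in combo:
        point.update(part)
    points.append(point)

print(len(points))  # 12 = 3 settings x 4 param_x values
print(points[4])    # {'setting_a': 2.0, 'setting_b': 20.0, 'param_x': 0.0}
```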

Iteration & Access Patterns

Space supports standard Python iteration, integer indexing (including negative), and slicing. Slices return a new Space.

demo_state = {
    'coupling': GridAxis(0.0, 1.0, 6),
    'noise': UniformAxis(0.01, 0.1, 6),
    'threshold': 0.5
}
demo_space = Space(demo_state, mode='zip', key=jax.random.key(42))

# Iterate
for i, params in enumerate(demo_space):
    if i >= 3: break
    print(f"  {i}: coupling={params['coupling']:.2f}, noise={params['noise']:.4f}")

# Index and slice
print(demo_space[0]['coupling'], demo_space[-1]['coupling'])
subset = demo_space[1:4]
print(f"Subset: {len(subset)} combinations")
  0: coupling=0.00, noise=0.0701
  1: coupling=0.20, noise=0.0749
  2: coupling=0.40, noise=0.0214
0.0 1.0
Subset: 3 combinations
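The access pattern here (integers, negative indices, and slices that return a new object of the same type) is the standard Python sequence protocol. The sketch below uses a hypothetical `ToySpace` over a list of precomputed points. It is not the Space implementation.

```python
class ToySpace:
    """Hypothetical sequence of parameter points with len, int and slice indexing."""

    def __init__(self, points):
        self._points = list(points)

    def __len__(self):
        return len(self._points)

    def __getitem__(self, index):
        if isinstance(index, slice):
            # Slices return a new container of the same type
            return ToySpace(self._points[index])
        return self._points[index]  # ints, including negatives


space = ToySpace({'coupling': c / 5} for c in range(6))
subset = space[1:4]
print(space[-1], len(subset), type(subset).__name__)  # {'coupling': 1.0} 3 ToySpace
```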

Converting to DataFrames

Both spaces and execution results can be converted to pandas DataFrames for easy analysis and plotting. Column names are derived automatically from the pytree structure of your state.

# Parameter space as a DataFrame
df_params = demo_space.to_dataframe()
print(df_params.head())
   coupling     noise
0       0.0  0.070052
1       0.2  0.074934
2       0.4  0.021412
3       0.6  0.041885
4       0.8  0.047898

After running a computation, Result.to_dataframe() combines parameters and results into a single DataFrame, keeping them correctly aligned:

from tvboptim.execution import SequentialExecution

def compute_metric(state):
    return {'activity': state['coupling'] * 2 + state['noise'],
            'ratio': state['coupling'] / (state['noise'] + 1e-6)}

result = SequentialExecution(compute_metric, demo_space).run()
df = result.to_dataframe()
print(df.head())

   coupling     noise  activity      ratio
0       0.0  0.070052  0.070052   0.000000
1       0.2  0.074934  0.474934   2.668987
2       0.4  0.021412  0.821412  18.680656
3       0.6  0.041885  1.241885  14.324623
4       0.8  0.047898  1.647898  16.701727
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8.1, 4))

ax1.scatter(df['coupling'], df['activity'])
ax1.set_xlabel('Coupling')
ax1.set_ylabel('Activity')
ax1.set_title('Coupling vs Activity')
ax1.grid(True, alpha=0.3)

ax2.scatter(df['noise'], df['ratio'], c=df['coupling'], cmap='viridis')
ax2.set_xlabel('Noise')
ax2.set_ylabel('Ratio')
ax2.set_title('Noise vs Ratio (colored by Coupling)')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

Nested state structures produce dot-separated column names (e.g. model.G, noise.sigma), and non-scalar results like arrays or matrices are stored as object cells in the DataFrame.
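The dot-separated naming can be reproduced for nested dicts with a short recursive helper. `flatten` is hypothetical and handles dicts only, while `to_dataframe` works from the pytree structure, so it also covers config objects. A list of such flat rows can be passed directly to `pandas.DataFrame`.

```python
def flatten(tree, prefix=''):
    """Flatten nested dicts into {'a.b': leaf} using dot-separated paths."""
    flat = {}
    for key, value in tree.items():
        name = f"{prefix}.{key}" if prefix else str(key)
        if isinstance(value, dict):
            flat.update(flatten(value, name))
        else:
            flat[name] = value
    return flat


row = flatten({'model': {'G': 0.5}, 'noise': {'sigma': 1e-3}})
print(row)  # {'model.G': 0.5, 'noise.sigma': 0.001}
```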

Execution: Running Functions Over Spaces

The axis and space system works with any pure JAX function — not just TVB network models. The examples here use a real brain network, which is the typical use case.

import copy
from tvboptim.execution import SequentialExecution, ParallelExecution
from tvboptim.experimental.network_dynamics import Network, prepare
from tvboptim.experimental.network_dynamics.dynamics.tvb import ReducedWongWang
from tvboptim.experimental.network_dynamics.coupling import LinearCoupling
from tvboptim.experimental.network_dynamics.graph import DenseGraph
from tvboptim.experimental.network_dynamics.noise import AdditiveNoise
from tvboptim.experimental.network_dynamics.solvers import Heun

n_nodes = 8
network = Network(
    dynamics=ReducedWongWang(),
    coupling={'instant': LinearCoupling(incoming_states='S', G=0.5)},
    graph=DenseGraph(jnp.ones((n_nodes, n_nodes)) - jnp.eye(n_nodes)),
    noise=AdditiveNoise(sigma=1e-3, key=jax.random.key(0)),
)
solve_fn, base_cfg = prepare(network, Heun(), t0=0.0, t1=5.0, dt=0.1)
print(f"Prepared: {n_nodes}-node RWW network, {int(5.0/0.1)} time steps")
Prepared: 8-node RWW network, 50 time steps

prepare returns a compiled solve function and a config object. Axes attach directly to config leaves — solve_fn stays constant across the whole sweep.

def observe(cfg):
    ys = solve_fn(cfg).ys  # shape: [n_steps, n_state_vars, n_nodes]
    return {'mean_S': ys.mean(), 'std_S': ys.std()}

Sequential Execution

Sequential execution processes one parameter combination at a time. Use it for debugging, memory-constrained environments, or when your model doesn’t vectorize cleanly.

seq_state = copy.deepcopy(base_cfg)
seq_state.coupling.instant.G = GridAxis(0.1, 1.0, 8)

seq_space = Space(seq_state, mode='product')
seq_results = SequentialExecution(observe, seq_space).run()

df_seq = seq_results.to_dataframe()
print(df_seq)

   coupling.instant.G    mean_S     std_S
0            0.100000  0.101733  0.001502
1            0.228571  0.105092  0.003031
2            0.357143  0.110537  0.006222
3            0.485714  0.118753  0.011529
4            0.614286  0.130028  0.019266
5            0.742857  0.144003  0.029189
6            0.871429  0.160107  0.040861
7            1.000000  0.177952  0.054001
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8.1, 4))

ax1.plot(df_seq['coupling.instant.G'], df_seq['mean_S'], 'o-')
ax1.set_xlabel('Coupling G')
ax1.set_ylabel('Mean S')
ax1.set_title('Mean activity vs coupling strength')
ax1.grid(True, alpha=0.3)

ax2.plot(df_seq['coupling.instant.G'], df_seq['std_S'], 'o-', color='orange')
ax2.set_xlabel('Coupling G')
ax2.set_ylabel('Std S')
ax2.set_title('Activity variability vs coupling strength')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

Parallel Execution

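Parallel execution uses JAX’s vmap to vectorize the model over batches of parameter combinations and pmap to spread those batches across devices. The n_vmap and n_pmap arguments control how the space is split between the two. With large parameter spaces, this is where the runtime difference becomes significant.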

par_state = copy.deepcopy(base_cfg)
par_state.coupling.instant.G = GridAxis(0.1, 1.0, 8)
par_state.noise.sigma = GridAxis(1e-4, 5e-3, 4)

par_space = Space(par_state, mode='product')  # 8 × 4 = 32 combinations
print(f"Parameter space: {len(par_space)} combinations")

par_results = ParallelExecution(observe, par_space, n_vmap=8, n_pmap=1).run()
print(f"Done: {len(par_results)} results")
Parameter space: 32 combinations
Done: 32 results
df_par = par_results.to_dataframe()

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8.1, 4))

for sigma, group in df_par.groupby('noise.sigma'):
    ax1.plot(group['coupling.instant.G'], group['mean_S'],
             'o-', alpha=0.7, label=f'σ={sigma:.4f}')
ax1.set_xlabel('Coupling G')
ax1.set_ylabel('Mean S')
ax1.set_title('G sweep across noise amplitudes')
ax1.legend(fontsize=8)
ax1.grid(True, alpha=0.3)

scatter = ax2.scatter(df_par['coupling.instant.G'], df_par['mean_S'],
                      c=df_par['noise.sigma'], cmap='plasma', s=40)
ax2.set_xlabel('Coupling G')
ax2.set_ylabel('Mean S')
ax2.set_title('Mean S across parameter space')
plt.colorbar(scatter, ax=ax2, label='noise σ')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
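The gain from vectorization comes from evaluating many combinations in one array operation instead of one call per point. NumPy broadcasting shows the same idea on a toy function, `toy_model`, which is hypothetical. This is an analogy only: ParallelExecution itself relies on JAX's vmap/pmap, not NumPy.

```python
import numpy as np


def toy_model(G, sigma):
    # Works on scalars and, through broadcasting, on whole arrays
    return G * np.tanh(G) + sigma


G = np.linspace(0.1, 1.0, 8)
sigma = np.linspace(1e-4, 5e-3, 4)

# Loop version: one call per (G, sigma) pair, product order
looped = np.array([[toy_model(g, s) for s in sigma] for g in G])

# Vectorized version: a single call over an (8, 4) grid
vectorized = toy_model(G[:, None], sigma[None, :])

print(vectorized.shape)                 # (8, 4)
print(np.allclose(looped, vectorized))  # True
```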

Performance Comparison

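For trivial models, JIT compilation and dispatch overhead can outweigh the benefit of vectorization, so sequential execution may win at small scale. Where the crossover lies depends on model complexity, problem size, and hardware; in the run below, parallel execution was already faster on 25 combinations. Both timings are single runs and include any first-call compilation.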

import time

timing_state = {
    'param1': GridAxis(0.0, 1.0, 5),
    'param2': UniformAxis(0.0, 1.0, 5)
}
timing_space = Space(timing_state, mode='product', key=jax.random.key(789))

def simple_model(state):
    return {'result': state['param1'] * state['param2']}

print("Performance comparison on 25 parameter combinations:")
print("=" * 55)

seq_exec = SequentialExecution(simple_model, timing_space)
start = time.perf_counter()
seq_results = seq_exec.run()
seq_time = time.perf_counter() - start
print(f"Sequential execution: {seq_time:.4f} seconds")

par_exec = ParallelExecution(simple_model, timing_space, n_vmap=5, n_pmap=1)
start = time.perf_counter()
par_results = par_exec.run()
par_time = time.perf_counter() - start
print(f"Parallel execution:   {par_time:.4f} seconds")

if seq_time > par_time:
    print(f"Parallel speedup:     {seq_time / par_time:.1f}x faster")
else:
    print("Sequential was faster (overhead dominates for small problems)")
Performance comparison on 25 parameter combinations:
=======================================================

Sequential execution: 0.2304 seconds
Parallel execution:   0.0512 seconds
Parallel speedup:     4.5x faster
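A single wall-clock measurement mixes compilation, dispatch, and compute. A more careful measurement makes one warm-up call first and then reports the best of several runs. The stdlib sketch below does this. `time_call` is a hypothetical helper, and the JAX-specific step is shown only as a comment.

```python
import time


def time_call(fn, *args, repeats=5):
    """Return the best wall-clock time of fn(*args) after one warm-up call."""
    fn(*args)  # warm-up: with JAX, this is where compilation would happen
    best = float('inf')
    for _ in range(repeats):
        start = time.perf_counter()
        result = fn(*args)
        # For JAX outputs, call jax.block_until_ready(result) here so that
        # asynchronous dispatch does not hide the real compute time.
        best = min(best, time.perf_counter() - start)
    return best


elapsed = time_call(sum, range(100_000))
print(f"best of 5: {elapsed:.6f} s")
```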

When to Use Each Approach

Sequential Execution:

  • Debugging: Easy to trace through individual parameter combinations
  • Memory constraints: Processes one combination at a time
  • Complex models: When models don’t easily vectorize
  • Small parameter spaces: Parallelization overhead (compilation, dispatch) may not pay off

Parallel Execution:

  • Large parameter spaces: Speedups grow with the number of combinations, with the largest gains for hundreds or thousands
  • JAX-compatible models: Models that can be easily vectorized
  • Production runs: When you need results fast
  • Multi-device setups: Can use multiple GPUs/TPUs

Integration with Spaces

Both execution types work with all space features. Any axis type can be placed on any config leaf:

mixed_state = copy.deepcopy(base_cfg)
mixed_state.coupling.instant.G = NumPyroAxis(dist.Beta(2.0, 5.0), n=4)
mixed_state.noise.sigma = DataAxis(jnp.array([1e-4, 5e-4, 1e-3]))

mixed_space = Space(mixed_state, mode='product', key=jax.random.key(0))
print(f"{len(mixed_space)} combinations (4 coupling samples × 3 noise levels)")

result = SequentialExecution(observe, mixed_space).run()
print(result.to_dataframe())
12 combinations (4 coupling samples × 3 noise levels)

    coupling.instant.G  noise.sigma    mean_S     std_S
0             0.681876       0.0001  0.136815  0.024235
1             0.681876       0.0005  0.136931  0.024231
2             0.681876       0.0010  0.137077  0.024240
3             0.138849       0.0001  0.102423  0.001367
4             0.138849       0.0005  0.102488  0.001468
5             0.138849       0.0010  0.102571  0.001803
6             0.532662       0.0001  0.122302  0.014069
7             0.532662       0.0005  0.122400  0.014060
8             0.532662       0.0010  0.122522  0.014074
9             0.422139       0.0001  0.114119  0.008583
10            0.422139       0.0005  0.114204  0.008576
11            0.422139       0.0010  0.114311  0.008609

Scanning Noise Seeds for Stochastic Simulations

When the model is stochastic (a Network with an AbstractNoise attached, integrated with a native solver), the PRNG key that drives the per-step noise increments lives at config.noise.key. Because the key is a regular leaf on the prepared config, you can place an axis on it the same way you place axes on any other parameter — no wrapper function, no per-point key swapping. The axis system substitutes the leaf with the per-point value before each call.

# Attach axes directly to the config — the leaves they target are
# substituted with per-point values at execution time.
grid_state = copy.deepcopy(base_cfg)
grid_state.coupling.instant.G = GridAxis(0.3, 1.5, 8)                       # Coupling sweep
grid_state.noise.key          = DataAxis(jax.random.split(jax.random.key(0), 8))  # Noise replicates

space = Space(grid_state, mode='product')

def observation(cfg):
    return solve_fn(cfg).ys.mean()

observations = ParallelExecution(observation, space, n_vmap=8, n_pmap=1).run()

Two things to know:

  • No re-prepare per seed. The PRNG key is a config leaf, so the axis system substitutes it like any other parameter. jax.vmap over the resulting key batch dimension batches the single jax.random.normal(config.noise.key, ...) call inside the solve function, with no compile churn beyond the initial JIT.
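  • Common random numbers across parameter sweeps. In product mode, as above, the key axis combines independently with the G axis. Each noise replicate is therefore reused at every G value, and all G points within a replicate see the same realised noise trajectory. Leaving config.noise.key as a plain value instead shares a single trajectory across the whole sweep. Zipping G with the keys would pair each G point with a different key, which does not give common random numbers. This is the classic variance-reduction trick for finite-difference sensitivity estimates.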
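The variance-reduction effect is easy to see on a toy problem. Below, a hypothetical `noisy_metric` stands in for a stochastic simulation, and a NumPy seed stands in for config.noise.key. The finite-difference derivative of G² at G = 0.5 (true value 2G = 1.0, or 2G + h for the forward difference) is estimated with shared versus independent noise.

```python
import numpy as np


def noisy_metric(G, seed, n=200):
    # Toy stochastic "simulation": deterministic part plus seeded noise
    noise = np.random.default_rng(seed).normal(0.0, 0.1, n)
    return G**2 + noise.mean()


G, h, seeds = 0.5, 1e-2, range(8)

# Common random numbers: both evaluations share a seed, so the noise cancels
crn = [(noisy_metric(G + h, s) - noisy_metric(G, s)) / h for s in seeds]

# Independent noise: the perturbed run uses a different seed
ind = [(noisy_metric(G + h, s + 100) - noisy_metric(G, s)) / h for s in seeds]

print(np.std(crn), np.std(ind))  # near zero vs. order 1
```

Dividing by a small h amplifies any noise that does not cancel, which is why sharing the noise realisation matters most for finite-difference estimates.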

For workflows that need the increments materialised as an addressable tensor (NumPyro inference over Brownian increments, deterministic replay of a recorded trajectory), populate config._internal.noise_samples with an array of shape [n_steps, n_noise_states, n_nodes]. The scan branches on this field at trace time and uses it in place of in-scan sampling. Flipping between None and an array triggers a one-time JIT retrace. The injection slot is native-only — the Diffrax dispatch supports the same config.noise.key swap pattern shown above, but its VirtualBrownianTree cannot consume a pre-sampled array.