Axes and Spaces: Systematic Parameter Exploration in TVB-Optim

Introduction & Overview

The Axes and Spaces system in TVB-Optim provides a powerful, composable framework for systematic parameter exploration in brain network simulations. This system enables you to:

  • Define parameter sampling strategies using different axis types
  • Compose complex parameter spaces from multiple axes
  • Leverage JAX transformations for efficient parallel execution
  • Seamlessly integrate with optimization and analysis workflows

Why This Design?

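Traditional parameter exploration often requires writing custom nested loops and managing indexing by hand. In TVB-Optim, you instead declare how each parameter is sampled by placing axis objects directly in the state tree. A Space then enumerates the resulting combinations and supports iteration, indexing, slicing, and parallel execution.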

Let’s start with a quick preview of what we’ll build:

import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from tvboptim.types.spaces import Space, GridAxis, UniformAxis, DataAxis, NumPyroAxis

# Quick preview: A mixed parameter space
preview_state = {
    'coupling_strength': GridAxis(0.0, 2.0, 4),
    'noise_level': UniformAxis(0.01, 0.1, 3),
    'fixed_param': 42.0
}

space = Space(preview_state, mode='product', key=jax.random.key(123))
print(f"Created parameter space with {len(space)} combinations")

# Show first few parameter combinations
for i, params in enumerate(space):
    if i >= 5: break
    print(f"Combination {i}: coupling={params['coupling_strength']:.2f}, "
          f"noise={params['noise_level']:.4f}, fixed={params['fixed_param']}")
Created parameter space with 12 combinations
Combination 0: coupling=0.00, noise=0.0265, fixed=42.0
Combination 1: coupling=0.00, noise=0.0876, fixed=42.0
Combination 2: coupling=0.00, noise=0.0204, fixed=42.0
Combination 3: coupling=0.67, noise=0.0265, fixed=42.0
Combination 4: coupling=0.67, noise=0.0876, fixed=42.0

Understanding Axes

The AbstractAxis Interface

All axes in TVB-Optim implement a simple but powerful interface:

  • generate_values(key=None): Produces sample values as JAX arrays
  • size: Returns the number of samples this axis generates

# Let's examine the interface
from tvboptim.types.spaces import GridAxis

grid = GridAxis(0.0, 1.0, 5)
print(f"Axis size: {grid.size}")
print(f"Generated values: {grid.generate_values()}")
print(f"Values shape: {grid.generate_values().shape}")
Axis size: 5
Generated values: [0.   0.25 0.5  0.75 1.  ]
Values shape: (5,)

Each axis type can implement completely different sampling strategies while maintaining a consistent interface.
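To make the two-member contract concrete, here is a minimal sketch of an object that satisfies it with a different sampling strategy (geometric spacing). The class ToyGeometricAxis is hypothetical and not part of TVB-Optim. It returns plain Python lists instead of JAX arrays so the sketch runs without dependencies, and it does not subclass AbstractAxis because the exact requirements of that base class are not shown here.

```python
class ToyGeometricAxis:
    """Hypothetical axis: n geometrically spaced values between low and high (both > 0)."""

    def __init__(self, low, high, n):
        self.low, self.high, self.n = low, high, n

    @property
    def size(self):
        # Number of samples this axis generates
        return self.n

    def generate_values(self, key=None):
        # Deterministic sampling: the key is accepted for interface compatibility but ignored
        if self.n == 1:
            return [self.low]
        ratio = (self.high / self.low) ** (1 / (self.n - 1))
        return [self.low * ratio ** i for i in range(self.n)]


axis = ToyGeometricAxis(0.1, 10.0, 3)
print(axis.size)               # 3
print(axis.generate_values())  # approximately [0.1, 1.0, 10.0]
```

Because callers only rely on size and generate_values, any sampling rule can sit behind the same interface.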

GridAxis - Deterministic Sampling

GridAxis provides systematic, deterministic sampling across a parameter range, which makes it a natural fit for sensitivity analysis and systematic exploration.

# Basic grid sampling
grid_basic = GridAxis(0.0, 2.0, 11)
values_basic = grid_basic.generate_values()

print(f"Grid values: {values_basic}")
print(f"Spacing is uniform: {jnp.allclose(jnp.diff(values_basic), jnp.diff(values_basic)[0])}")

# Multiple grid densities
densities = [5, 10, 20]
for n in densities:
    grid = GridAxis(0.0, 1.0, n)
    values = grid.generate_values()
    print(f"n={n:2d}: {len(values)} values from {values[0]:.3f} to {values[-1]:.3f}")
Grid values: [0.        0.2       0.4       0.6       0.8       1.        1.2
 1.4       1.6       1.8000001 2.       ]
Spacing is uniform: True
n= 5: 5 values from 0.000 to 1.000
n=10: 10 values from 0.000 to 1.000
n=20: 20 values from 0.000 to 1.000
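The printed values are consistent with evenly spaced points that include both endpoints, i.e. value_i = low + i * (high - low) / (n - 1) for i = 0, 1, ..., n - 1, the same convention as numpy.linspace. The sketch below reproduces the n = 11 grid in plain Python under that assumption. The 1.8000001 in the output above comes from float32 rounding and does not indicate uneven spacing.

```python
def linspace_values(low, high, n):
    """Evenly spaced values from low to high, both endpoints included (assumed GridAxis convention)."""
    if n == 1:
        return [low]
    step = (high - low) / (n - 1)
    return [low + i * step for i in range(n)]


values = linspace_values(0.0, 2.0, 11)
print([round(v, 6) for v in values])
# [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6, 1.8, 2.0]
```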
# Visualize grid sampling
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8.1, 4))

# Plot 1: Basic grid values
ax1.scatter(range(len(values_basic)), values_basic, color='blue', s=50)
ax1.set_xlabel('Sample Index')
ax1.set_ylabel('Parameter Value')
ax1.set_title('GridAxis: Linear Spacing')
ax1.grid(True, alpha=0.3)

# Plot 2: Multiple grid densities
colors = ['red', 'green', 'blue']
for i, (n, color) in enumerate(zip(densities, colors)):
    grid = GridAxis(0.0, 1.0, n)
    values = grid.generate_values()
    ax2.scatter(values, [i] * len(values), color=color, s=30,
               label=f'n={n}', alpha=0.7)

ax2.set_xlabel('Parameter Value')
ax2.set_ylabel('Grid Density')
ax2.set_title('GridAxis: Different Densities')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

Shape Broadcasting with GridAxis

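GridAxis also accepts a shape argument. Each grid value is then broadcast to an array of that shape, so every sample already has the per-region shape a regional parameter needs, while all regions share the same value within a sample. This is useful when such a parameter will later be optimized region by region: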

# Example: Regional coupling strengths (68 brain regions)
n_regions = 68
regional_grid = GridAxis(0.0, 1.0, 5, shape=(n_regions,))
regional_values = regional_grid.generate_values()

print(f"Regional grid shape: {regional_values.shape}")
print(f"Each sample has shape: {regional_values[0].shape}")
print(f"Sample 2 value: {regional_values[2][0]:.3f} (broadcasted to all regions)")
print(f"All regions identical per sample: {jnp.allclose(regional_values[2], regional_values[2][0])}")
Regional grid shape: (5, 68)
Each sample has shape: (68,)
Sample 2 value: 0.500 (broadcasted to all regions)
All regions identical per sample: True
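A dependency-free sketch of the broadcasting rule visible in the output above uses nested lists in place of JAX arrays. The helpers broadcast_to_shape and shape_of are hypothetical illustrations, not TVB-Optim functions.

```python
def broadcast_to_shape(value, shape):
    """Return nested lists of the given shape filled with value (hypothetical helper)."""
    if not shape:
        return value
    return [broadcast_to_shape(value, shape[1:]) for _ in range(shape[0])]


def shape_of(nested):
    """Shape of a regular nested list, read from its first elements."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return tuple(dims)


grid = [i / 4 for i in range(5)]                        # 5 evenly spaced values in [0, 1]
samples = [broadcast_to_shape(v, (68,)) for v in grid]  # one 68-element array per sample
print(shape_of(samples))  # (5, 68)
print(set(samples[2]))    # {0.5}
```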
# Visualize the broadcasting
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8.1, 4))

# Show that each sample broadcasts the same value
sample_idx = 2
ax1.plot(regional_values[sample_idx], 'o-', markersize=3)
ax1.set_xlabel('Brain Region')
ax1.set_ylabel('Coupling Strength')
ax1.set_title(f'Sample {sample_idx}: Broadcasted Value = {regional_values[sample_idx][0]:.2f}')

# Show progression across samples
ax2.plot(range(5), regional_values[:, 0], 'o-', linewidth=2, markersize=8)
ax2.set_xlabel('Sample Index')
ax2.set_ylabel('Parameter Value')
ax2.set_title('Parameter Progression Across Samples')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

UniformAxis - Random Sampling

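UniformAxis draws random samples from a uniform distribution over the given range. This suits stochastic exploration, particularly when a dense grid over several parameters would be too large. The samples are determined by the JAX PRNG key, so the same key always reproduces the same values: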
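One reason random sampling is attractive is that a full grid grows exponentially with the number of parameters, while a random design lets you fix the number of runs up front. The arithmetic below compares the two for a hypothetical study with 6 parameters and 10 grid points per parameter. The budget of 2,000 random runs is also an illustrative assumption.

```python
n_params = 6          # hypothetical number of explored parameters
points_per_axis = 10  # hypothetical grid resolution per parameter

grid_runs = points_per_axis ** n_params
print(f"Full grid: {grid_runs:,} simulations")  # Full grid: 1,000,000 simulations

# With random sampling the number of runs is chosen directly
random_budget = 2_000
print(f"Random design: {random_budget:,} simulations "
      f"({grid_runs // random_budget}x fewer)")  # Random design: 2,000 simulations (500x fewer)
```

Whether a given random budget is sufficient depends on the model. The point is only that the cost no longer scales as points_per_axis ** n_params.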

# Reproducible random sampling
key = jax.random.key(42)
uniform = UniformAxis(0.0, 1.0, 100)

values1 = uniform.generate_values(key)
values2 = uniform.generate_values(key)  # Same key = same values
values3 = uniform.generate_values(jax.random.key(43))  # Different key

print(f"Same key gives identical results: {jnp.allclose(values1, values2)}")
print(f"Different key gives different results: {not jnp.allclose(values1, values3)}")
print(f"Mean value: {jnp.mean(values1):.3f} (should be ~0.5)")
print(f"Value range: [{jnp.min(values1):.3f}, {jnp.max(values1):.3f}]")
Same key gives identical results: True
Different key gives different results: True
Mean value: 0.514 (should be ~0.5)
Value range: [0.010, 0.992]
# Visualize distributions
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8.1, 4))

# Histogram of uniform samples
ax1.hist(values1, bins=20, alpha=0.7, color='green', density=True)
ax1.axhline(y=1.0, color='red', linestyle='--', label='Expected density = 1.0')
ax1.set_xlabel('Parameter Value')
ax1.set_ylabel('Density')
ax1.set_title('UniformAxis Distribution')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Comparison of different sample sizes
sample_sizes = [10, 50, 200]
for size in sample_sizes:
    uniform_temp = UniformAxis(-2.0, 2.0, size)
    samples = uniform_temp.generate_values(jax.random.key(42))
    ax2.scatter(samples, [size] * len(samples), alpha=0.6, s=20,
               label=f'n={size}')

ax2.set_xlabel('Parameter Value')
ax2.set_ylabel('Sample Size')
ax2.set_title('UniformAxis: Different Sample Sizes')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

Shape Broadcasting with UniformAxis

UniformAxis also supports shape broadcasting, useful for parameters that need the same random value across multiple dimensions:

# Compare independent vs broadcast sampling
key = jax.random.key(123)

# Independent sampling: each element is different
uniform_independent = UniformAxis(0.0, 1.0, 3, shape=None)
independent_values = uniform_independent.generate_values(key)

# Broadcast sampling: same value across shape dimensions
uniform_broadcast = UniformAxis(0.0, 1.0, 3, shape=(4, 2))
broadcast_values = uniform_broadcast.generate_values(key)

print("Independent sampling:")
print(f"Shape: {independent_values.shape}")
print(f"Values: {independent_values}")

print("\nBroadcast sampling:")
print(f"Shape: {broadcast_values.shape}")
print("First sample (4x2 matrix):")
print(broadcast_values[0])
print(f"All values in first sample identical: {jnp.allclose(broadcast_values[0], broadcast_values[0].flatten()[0])}")
Independent sampling:
Shape: (3,)
Values: [0.9490746 0.7997726 0.5088254]

Broadcast sampling:
Shape: (3, 4, 2)
First sample (4x2 matrix):
[[0.9490746 0.9490746]
 [0.9490746 0.9490746]
 [0.9490746 0.9490746]
 [0.9490746 0.9490746]]
All values in first sample identical: True

DataAxis - Predefined Values

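DataAxis lets you use a predefined sequence of values, such as values reported in the literature. Each entry along the first dimension of the array is one sample, so the axis size equals the length of that dimension. A stack of matrices therefore yields one matrix per sample.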

# Example: Testing specific coupling values from literature
literature_values = jnp.array([0.142, 0.284, 0.426, 0.568, 0.710])
data_axis = DataAxis(literature_values)

print(f"DataAxis size: {data_axis.size}")
print(f"Values: {data_axis.generate_values()}")

# Works with multidimensional data too
connectivity_matrices = jnp.array([
    [[1.0, 0.5], [0.5, 1.0]],  # Weak coupling
    [[1.0, 0.8], [0.8, 1.0]],  # Medium coupling
    [[1.0, 1.2], [1.2, 1.0]]   # Strong coupling
])

matrix_axis = DataAxis(connectivity_matrices)
print(f"\nMatrix axis shape: {matrix_axis.generate_values().shape}")
print("First connectivity matrix:")
print(matrix_axis.generate_values()[0])

# Create some creative data patterns
fib_values = jnp.array([1, 1, 2, 3, 5, 8, 13, 21])
fib_normalized = fib_values / fib_values.max()
print(f"\nFibonacci sequence (normalized): {fib_normalized}")

# Oscillatory pattern
t = jnp.linspace(0, 4*jnp.pi, 20)
oscillatory = 0.5 + 0.3 * jnp.sin(t)
print(f"Oscillatory pattern range: [{jnp.min(oscillatory):.3f}, {jnp.max(oscillatory):.3f}]")
DataAxis size: 5
Values: [0.142 0.284 0.426 0.568 0.71 ]

Matrix axis shape: (3, 2, 2)
First connectivity matrix:
[[1.  0.5]
 [0.5 1. ]]

Fibonacci sequence (normalized): [0.04761905 0.04761905 0.0952381  0.14285715 0.23809524 0.3809524
 0.61904764 1.        ]
Oscillatory pattern range: [0.201, 0.799]
# Visualize different data patterns
fig, axes = plt.subplots(2, 2, figsize=(8.1, 8))

# 1. Literature values
ax = axes[0, 0]
ax.scatter(range(len(literature_values)), literature_values,
          color='red', s=100, marker='s')
ax.set_xlabel('Index')
ax.set_ylabel('Coupling Strength')
ax.set_title('DataAxis: Literature Values')
ax.grid(True, alpha=0.3)

# 2. Fibonacci sequence
fib_axis = DataAxis(fib_normalized)
ax = axes[0, 1]
ax.plot(fib_normalized, 'o-', color='purple', linewidth=2, markersize=8)
ax.set_xlabel('Index')
ax.set_ylabel('Normalized Value')
ax.set_title('DataAxis: Fibonacci Sequence')
ax.grid(True, alpha=0.3)

# 3. Oscillatory pattern
osc_axis = DataAxis(oscillatory)
ax = axes[1, 0]
ax.plot(oscillatory, 'o-', color='orange', linewidth=2, markersize=6)
ax.set_xlabel('Index')
ax.set_ylabel('Parameter Value')
ax.set_title('DataAxis: Oscillatory Pattern')
ax.grid(True, alpha=0.3)

# 4. Connectivity matrix visualization
ax = axes[1, 1]
im = ax.imshow(matrix_axis.generate_values()[1], cmap='viridis', vmin=0, vmax=1.2)
ax.set_title('DataAxis: Connectivity Matrix (Medium)')
plt.colorbar(im, ax=ax, fraction=0.046)

plt.tight_layout()
plt.show()

NumPyroAxis - Distribution Sampling

NumPyroAxis draws samples from a NumPyro probability distribution, so you can encode prior knowledge about a parameter's plausible values. This is useful for uncertainty quantification and for Bayesian workflows.

import numpyro.distributions as dist

# Different distribution types
key = jax.random.key(42)

# Normal distribution - common for many biological parameters
normal_axis = NumPyroAxis(dist.Normal(0.5, 0.15), n=1000)
normal_samples = normal_axis.generate_values(key)

# Beta distribution - great for bounded parameters [0,1]
beta_axis = NumPyroAxis(dist.Beta(2.0, 5.0), n=1000)
beta_samples = beta_axis.generate_values(key)

# LogNormal - for positive-only parameters like time constants
lognormal_axis = NumPyroAxis(dist.LogNormal(0.0, 0.5), n=1000)
lognormal_samples = lognormal_axis.generate_values(key)

print(f"Normal samples: mean={jnp.mean(normal_samples):.3f}, std={jnp.std(normal_samples):.3f}")
print(f"Beta samples: mean={jnp.mean(beta_samples):.3f}, range=[{jnp.min(beta_samples):.3f}, {jnp.max(beta_samples):.3f}]")
print(f"LogNormal samples: mean={jnp.mean(lognormal_samples):.3f}, median={jnp.median(lognormal_samples):.3f}")
Normal samples: mean=0.506, std=0.146
Beta samples: mean=0.285, range=[0.007, 0.933]
LogNormal samples: mean=1.145, median=1.045
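As a sanity check, the sample statistics above can be compared with the analytic values of each distribution. The formulas are standard: Normal(mu, sigma) has mean mu and standard deviation sigma, Beta(a, b) has mean a / (a + b), and LogNormal(mu, sigma) (parameterized by the underlying normal, as in NumPyro) has mean exp(mu + sigma^2 / 2) and median exp(mu). The printed sample values lie within a few percent of these.

```python
import math

# Analytic reference values for the distributions sampled above
normal_mean, normal_std = 0.5, 0.15            # Normal(0.5, 0.15)
beta_mean = 2.0 / (2.0 + 5.0)                  # Beta(2.0, 5.0)
lognormal_mean = math.exp(0.0 + 0.5 ** 2 / 2)  # LogNormal(0.0, 0.5)
lognormal_median = math.exp(0.0)

print(f"Normal:    mean={normal_mean:.3f}, std={normal_std:.3f}")              # mean=0.500, std=0.150
print(f"Beta:      mean={beta_mean:.3f}")                                      # mean=0.286
print(f"LogNormal: mean={lognormal_mean:.3f}, median={lognormal_median:.3f}")  # mean=1.133, median=1.000
```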
# Visualize the distributions
fig, axes = plt.subplots(2, 2, figsize=(8.1, 8))

# Normal distribution
ax = axes[0, 0]
ax.hist(normal_samples, bins=50, density=True, alpha=0.7, color='blue')
ax.set_xlabel('Value')
ax.set_ylabel('Density')
ax.set_title('Normal(0.5, 0.15)')
ax.grid(True, alpha=0.3)

# Beta distribution
ax = axes[0, 1]
ax.hist(beta_samples, bins=50, density=True, alpha=0.7, color='green')
ax.set_xlabel('Value')
ax.set_ylabel('Density')
ax.set_title('Beta(2.0, 5.0)')
ax.grid(True, alpha=0.3)

# LogNormal distribution
ax = axes[1, 0]
ax.hist(lognormal_samples, bins=50, density=True, alpha=0.7, color='red')
ax.set_xlabel('Value')
ax.set_ylabel('Density')
ax.set_title('LogNormal(0.0, 0.5)')
ax.grid(True, alpha=0.3)

# Comparison of all three
ax = axes[1, 1]
ax.hist(normal_samples, bins=50, density=True, alpha=0.5, color='blue', label='Normal')
ax.hist(beta_samples, bins=50, density=True, alpha=0.5, color='green', label='Beta')
ax.hist(lognormal_samples[lognormal_samples < 5], bins=50, density=True, alpha=0.5, color='red', label='LogNormal')
ax.set_xlabel('Value')
ax.set_ylabel('Density')
ax.set_title('Distribution Comparison')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

Independent vs Broadcast Sampling

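NumPyroAxis supports two sampling modes, selected with broadcast_mode. With broadcast_mode=False, every element of the sample_shape array is drawn independently, which suits heterogeneous parameters such as per-region values. With broadcast_mode=True, a single value is drawn per axis point and broadcast to sample_shape, which suits a parameter shared by all elements: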

key = jax.random.key(456)

# Independent mode: each element sampled independently
independent_axis = NumPyroAxis(
    dist.Normal(0.0, 1.0),
    n=3,
    sample_shape=(2, 4),
    broadcast_mode=False
)
independent_samples = independent_axis.generate_values(key)

# Broadcast mode: one sample per axis point, broadcast to shape
broadcast_axis = NumPyroAxis(
    dist.Normal(0.0, 1.0),
    n=3,
    sample_shape=(2, 4),
    broadcast_mode=True
)
broadcast_samples = broadcast_axis.generate_values(key)

print("Independent sampling:")
print(f"Shape: {independent_samples.shape}")
print("Sample 0 (each element different):")
print(independent_samples[0])
print()

print("Broadcast sampling:")
print(f"Shape: {broadcast_samples.shape}")
print("Sample 0 (all elements identical):")
print(broadcast_samples[0])
print(f"All elements identical: {jnp.allclose(broadcast_samples[0], broadcast_samples[0].flatten()[0])}")
Independent sampling:
Shape: (3, 2, 4)
Sample 0 (each element different):
[[ 0.991294   -1.2127206  -0.28042838  0.02664931]
 [-0.07197085  0.3858999  -0.5076543  -1.5987304 ]]

Broadcast sampling:
Shape: (3, 2, 4)
Sample 0 (all elements identical):
[[0.991294 0.991294 0.991294 0.991294]
 [0.991294 0.991294 0.991294 0.991294]]
All elements identical: True
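The two modes differ in how many random draws are made per axis point: one draw per element of sample_shape versus a single draw that is repeated. A stdlib sketch of that logic follows. It uses random.gauss as a stand-in for a NumPyro distribution, stores each sample as a flat list for simplicity, and does not reproduce NumPyro's random streams. The helper sample_axis is hypothetical.

```python
import math
import random


def sample_axis(n, sample_shape, broadcast_mode, seed=0):
    """Return n samples, each a flat list of prod(sample_shape) values (hypothetical helper)."""
    rng = random.Random(seed)
    size = math.prod(sample_shape)
    samples = []
    for _ in range(n):
        if broadcast_mode:
            value = rng.gauss(0.0, 1.0)     # one draw per axis point...
            samples.append([value] * size)  # ...repeated over sample_shape
        else:
            samples.append([rng.gauss(0.0, 1.0) for _ in range(size)])  # one draw per element
    return samples


independent = sample_axis(3, (2, 4), broadcast_mode=False)
broadcast = sample_axis(3, (2, 4), broadcast_mode=True)
print(len(set(independent[0])))  # 8 distinct values
print(len(set(broadcast[0])))    # 1 distinct value
```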
# Visualize the difference
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8.1, 4))

# Independent sampling heatmap
im1 = ax1.imshow(independent_samples[0], cmap='RdBu_r', vmin=-3, vmax=3)
ax1.set_title('Independent Sampling\n(Each element different)')
ax1.set_xlabel('Dimension 2')
ax1.set_ylabel('Dimension 1')
plt.colorbar(im1, ax=ax1, fraction=0.046)

# Broadcast sampling heatmap
im2 = ax2.imshow(broadcast_samples[0], cmap='RdBu_r', vmin=-3, vmax=3)
ax2.set_title('Broadcast Sampling\n(All elements identical)')
ax2.set_xlabel('Dimension 2')
ax2.set_ylabel('Dimension 1')
plt.colorbar(im2, ax=ax2, fraction=0.046)

plt.tight_layout()
plt.show()

Composing with Space

Basic Space Creation

The Space class takes a state tree containing axes and fixed values and generates parameter combinations according to the chosen mode. Axes are expanded into their sampled values, while fixed values are passed unchanged into every combination.

# Create a realistic brain simulation parameter space
simulation_state = {
    'coupling': {
        'strength': GridAxis(0.1, 0.8, 5),      # Coupling strength
        'delay': UniformAxis(5.0, 25.0, 3),     # Transmission delays (ms)
    },
    'model': {
        'noise_amplitude': DataAxis([0.01, 0.02, 0.05]),  # Noise levels from pilot study
        'time_constant': 10.0,                    # Fixed parameter
    },
    'simulation': {
        'dt': 0.1,                               # Fixed time step
        'duration': 1000,                        # Fixed duration
    }
}

# Create the space
space = Space(simulation_state, mode='product', key=jax.random.key(789))
print(f"Parameter space contains {len(space)} combinations")
print(f"Combination count: 5 × 3 × 3 = {5*3*3} (Grid × Uniform × Data)")

# Look at a few parameter combinations
print("\nFirst 3 parameter combinations:")
print("=" * 50)

for i, params in enumerate(space):
    if i >= 3: break
    print(f"\nCombination {i}:")
    print(f"  Coupling strength: {params['coupling']['strength']:.3f}")
    print(f"  Coupling delay: {params['coupling']['delay']:.2f} ms")
    print(f"  Noise amplitude: {params['model']['noise_amplitude']:.3f}")
    print(f"  Time constant: {params['model']['time_constant']} (fixed)")
    print(f"  Simulation dt: {params['simulation']['dt']} (fixed)")
Parameter space contains 45 combinations
Combination count: 5 × 3 × 3 = 45 (Grid × Uniform × Data)

First 3 parameter combinations:
==================================================

Combination 0:
  Coupling strength: 0.100
  Coupling delay: 16.82 ms
  Noise amplitude: 0.010
  Time constant: 10.0 (fixed)
  Simulation dt: 0.1 (fixed)

Combination 1:
  Coupling strength: 0.100
  Coupling delay: 16.82 ms
  Noise amplitude: 0.020
  Time constant: 10.0 (fixed)
  Simulation dt: 0.1 (fixed)

Combination 2:
  Coupling strength: 0.100
  Coupling delay: 16.82 ms
  Noise amplitude: 0.050
  Time constant: 10.0 (fixed)
  Simulation dt: 0.1 (fixed)

Combination Modes: Product vs Zip

The two combination modes serve different exploration strategies:

Product Mode (Cartesian Product)

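Product mode tests every combination of parameter values (the Cartesian product), which makes it ideal for systematic exploration. The number of combinations is the product of the axis sizes, here 3 × 2 = 6.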

# Product mode: systematic exploration
product_state = {
    'param_a': GridAxis(0.0, 1.0, 3),
    'param_b': GridAxis(0.0, 1.0, 2),
}

product_space = Space(product_state, mode='product')
print(f"Product mode: {len(product_space)} combinations")

print("\nAll combinations (product mode):")
for i, params in enumerate(product_space):
    print(f"  {i}: a={params['param_a']:.1f}, b={params['param_b']:.1f}")
Product mode: 6 combinations

All combinations (product mode):
  0: a=0.0, b=0.0
  1: a=0.0, b=1.0
  2: a=0.5, b=0.0
  3: a=0.5, b=1.0
  4: a=1.0, b=0.0
  5: a=1.0, b=1.0
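The enumeration order above matches Python's itertools.product, in which the last axis varies fastest. The following sketch reproduces the six combinations with the standard library. It illustrates the combination rule only and makes no claim about how Space is implemented internally.

```python
import itertools

param_a = [0.0, 0.5, 1.0]  # values of GridAxis(0.0, 1.0, 3)
param_b = [0.0, 1.0]       # values of GridAxis(0.0, 1.0, 2)

combos = [{"param_a": a, "param_b": b} for a, b in itertools.product(param_a, param_b)]
print(len(combos))  # 6
for i, c in enumerate(combos):
    print(f"  {i}: a={c['param_a']:.1f}, b={c['param_b']:.1f}")
```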

Zip Mode (Parallel Sampling)

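Zip mode pairs corresponding elements from each axis, which suits matched sampling. The axes should therefore have the same size. If they do not, the space is truncated to the smallest axis size and a warning is printed.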

# Zip mode: matched sampling
zip_space = Space(product_state, mode='zip')
print(f"Zip mode: {len(zip_space)} combinations (uses minimum axis size)")

print("\nMatched combinations (zip mode):")
for i, params in enumerate(zip_space):
    print(f"  {i}: a={params['param_a']:.1f}, b={params['param_b']:.1f}")
WARNING: In zip mode, effective axes have different sizes [3, 2]. Using minimum size 2, losing 1 combinations.
Zip mode: 2 combinations (uses minimum axis size)

Matched combinations (zip mode):
  0: a=0.0, b=0.0
  1: a=0.5, b=1.0
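Zip mode is analogous to Python's built-in zip, which also stops at the shortest input. The sketch below mirrors the truncation and computes how many values are dropped, which matches the counts in the warnings shown in this guide for these two-axis cases. The message text is illustrative, not TVB-Optim's exact wording.

```python
param_a = [0.0, 0.5, 1.0]  # values of GridAxis(0.0, 1.0, 3)
param_b = [0.0, 1.0]       # values of GridAxis(0.0, 1.0, 2)

sizes = [len(param_a), len(param_b)]
n_effective = min(sizes)
if len(set(sizes)) > 1:
    # Mismatched sizes: only the first n_effective values of each axis are used
    print(f"sizes {sizes} differ: keeping {n_effective}, dropping {max(sizes) - n_effective}")

combos = [{"param_a": a, "param_b": b} for a, b in zip(param_a, param_b)]
print(combos)  # [{'param_a': 0.0, 'param_b': 0.0}, {'param_a': 0.5, 'param_b': 1.0}]
```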
# Create a more interesting comparison
comparison_state = {
    'x': GridAxis(0.0, 2.0, 8),
    'y': GridAxis(0.0, 2.0, 5),
}

product_comp = Space(comparison_state, mode='product')
zip_comp = Space(comparison_state, mode='zip')

# Extract coordinates for plotting
product_x = [p['x'] for p in product_comp]
product_y = [p['y'] for p in product_comp]

zip_x = [p['x'] for p in zip_comp]
zip_y = [p['y'] for p in zip_comp]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8.1, 5))

# Product mode visualization
ax1.scatter(product_x, product_y, color='blue', s=50, alpha=0.7)
ax1.set_xlabel('Parameter X')
ax1.set_ylabel('Parameter Y')
ax1.set_title(f'Product Mode\n({len(product_comp)} combinations)')
ax1.grid(True, alpha=0.3)
ax1.set_xlim(-0.2, 2.2)
ax1.set_ylim(-0.2, 2.2)

# Zip mode visualization
ax2.scatter(zip_x, zip_y, color='red', s=100, alpha=0.7)
ax2.plot(zip_x, zip_y, 'r--', alpha=0.5, linewidth=1)
ax2.set_xlabel('Parameter X')
ax2.set_ylabel('Parameter Y')
ax2.set_title(f'Zip Mode\n({len(zip_comp)} combinations)')
ax2.grid(True, alpha=0.3)
ax2.set_xlim(-0.2, 2.2)
ax2.set_ylim(-0.2, 2.2)

plt.tight_layout()
plt.show()
WARNING: In zip mode, effective axes have different sizes [8, 5]. Using minimum size 5, losing 3 combinations.

Grouped Axes: Mixing Combination Modes

In many real-world scenarios, you need both zip and product behavior in the same space. For example, you might have matched experimental settings that must move in lockstep, combined with a parameter grid you want to explore exhaustively.

The group parameter on axes solves this. Axes sharing the same group label are zipped together into a single composite axis, regardless of where they sit in the state tree. Then all groups and ungrouped axes combine using the Space’s mode (typically product).

Basic Example

# Three matched settings that must stay together
setting_values_a = jnp.array([1.0, 2.0, 3.0])
setting_values_b = jnp.array([10.0, 20.0, 30.0])

grouped_state = {
    'setting_a': DataAxis(setting_values_a, group='settings'),
    'setting_b': DataAxis(setting_values_b, group='settings'),
    'param_x': GridAxis(0.0, 1.0, 4),
}

space = Space(grouped_state, mode='product', key=jax.random.key(42))
print(f"Total combinations: {len(space)}")
print(f"  = 3 settings x 4 param_x = {3 * 4}")
print(f"  (NOT 3 x 3 x 4 = {3 * 3 * 4})")

print("\nAll combinations:")
for i, params in enumerate(space):
    print(f"  {i}: setting_a={params['setting_a']:.0f}, "
          f"setting_b={params['setting_b']:.0f}, "
          f"param_x={params['param_x']:.2f}")
Total combinations: 12
  = 3 settings x 4 param_x = 12
  (NOT 3 x 3 x 4 = 36)

All combinations:
  0: setting_a=1, setting_b=10, param_x=0.00
  1: setting_a=1, setting_b=10, param_x=0.33
  2: setting_a=1, setting_b=10, param_x=0.67
  3: setting_a=1, setting_b=10, param_x=1.00
  4: setting_a=2, setting_b=20, param_x=0.00
  5: setting_a=2, setting_b=20, param_x=0.33
  6: setting_a=2, setting_b=20, param_x=0.67
  7: setting_a=2, setting_b=20, param_x=1.00
  8: setting_a=3, setting_b=30, param_x=0.00
  9: setting_a=3, setting_b=30, param_x=0.33
  10: setting_a=3, setting_b=30, param_x=0.67
  11: setting_a=3, setting_b=30, param_x=1.00

Notice that setting_a and setting_b always move together (1 with 10, 2 with 20, 3 with 30) while param_x varies independently across its full grid.

Groups work across subtrees: grouped axes don't need to be co-located in the state tree. You can also use multiple independent groups, each with its own label. Each group is zipped internally, and the groups are then combined via the Space's mode. The group label can be any hashable value (strings, ints, tuples, etc.).
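The grouping rule works as follows: zip the axes within each group, treat each group as one composite axis, then take the product over groups and ungrouped axes. The sketch below applies that rule to the basic example above using plain lists and itertools. The function combine_with_groups is hypothetical and only illustrates the counting and ordering for this example, not TVB-Optim's implementation.

```python
import itertools


def combine_with_groups(axes):
    """axes: dict name -> (values, group or None). Zip within groups, product across them."""
    units = {}  # one entry per group label, plus one per ungrouped axis
    for name, (_, group) in axes.items():
        key = ("group", group) if group is not None else ("axis", name)
        units.setdefault(key, []).append(name)

    # Each unit becomes a list of partial dicts (zipped rows for a group)
    unit_rows = []
    for names in units.values():
        columns = [axes[n][0] for n in names]
        unit_rows.append([dict(zip(names, row)) for row in zip(*columns)])

    # Cartesian product over units, merging the partial dicts
    return [{k: v for part in parts for k, v in part.items()}
            for parts in itertools.product(*unit_rows)]


axes = {
    "setting_a": ([1.0, 2.0, 3.0], "settings"),
    "setting_b": ([10.0, 20.0, 30.0], "settings"),
    "param_x": ([0.0, 1 / 3, 2 / 3, 1.0], None),
}
combos = combine_with_groups(axes)
print(len(combos))  # 12
print(combos[4])    # {'setting_a': 2.0, 'setting_b': 20.0, 'param_x': 0.0}
```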

Iteration & Access Patterns

Space provides flexible ways to access parameter combinations:

# Create a space for demonstration
demo_state = {
    'coupling': GridAxis(0.0, 1.0, 6),
    'noise': UniformAxis(0.01, 0.1, 6),
    'threshold': 0.5  # Fixed value
}

demo_space = Space(demo_state, mode='zip', key=jax.random.key(42))
print(f"Demo space has {len(demo_space)} combinations")
Demo space has 6 combinations

Sequential Iteration

# Standard iteration - perfect for simulations
print("Sequential iteration (first 3):")
for i, params in enumerate(demo_space):
    if i >= 3: break
    print(f"  Run {i}: coupling={params['coupling']:.2f}, "
          f"noise={params['noise']:.4f}")
Sequential iteration (first 3):
  Run 0: coupling=0.00, noise=0.0701
  Run 1: coupling=0.20, noise=0.0749
  Run 2: coupling=0.40, noise=0.0214

Random Access

# Direct indexing - great for debugging specific combinations
print("Random access examples:")
print(f"  Combination 0: coupling = {demo_space[0]['coupling']:.3f}")
print(f"  Combination 2: coupling = {demo_space[2]['coupling']:.3f}")
print(f"  Last combination: coupling = {demo_space[-1]['coupling']:.3f}")

# Negative indexing works too
print(f"  Second to last: coupling = {demo_space[-2]['coupling']:.3f}")
Random access examples:
  Combination 0: coupling = 0.000
  Combination 2: coupling = 0.400
  Last combination: coupling = 1.000
  Second to last: coupling = 0.800
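Random access does not require enumerating earlier combinations. In zip mode the mapping is trivial, because index i selects the i-th value of every axis. In a product space, a flat index can be decoded into one index per axis with repeated divmod, last axis fastest, which matches the product order shown earlier. This mixed-radix decoding is the same idea as numpy.unravel_index. Whether Space uses exactly this internally is an assumption, and unravel_index below is a hypothetical helper.

```python
def unravel_index(flat_index, sizes):
    """Map a flat index into per-axis indices for a product space (last axis fastest)."""
    total = 1
    for s in sizes:
        total *= s
    if flat_index < 0:  # support negative indexing like Python sequences
        flat_index += total
    if not 0 <= flat_index < total:
        raise IndexError(f"index out of range for {total} combinations")
    indices = []
    for s in reversed(sizes):
        flat_index, i = divmod(flat_index, s)
        indices.append(i)
    return tuple(reversed(indices))


sizes = [3, 2]                   # e.g. param_a (3 values) x param_b (2 values)
print(unravel_index(3, sizes))   # (1, 1) -> a=0.5, b=1.0
print(unravel_index(-1, sizes))  # (2, 1) -> a=1.0, b=1.0
```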

Slicing for Subset Analysis

# Slicing creates new spaces - perfect for subset analysis
subset = demo_space[1:4]
print(f"Subset contains {len(subset)} combinations")

print("Subset combinations:")
for i, params in enumerate(subset):
    print(f"  Original index {i+1}: coupling={params['coupling']:.3f}")

# Slicing with steps
every_other = demo_space[::2]
print(f"\nEvery other combination ({len(every_other)} total):")
for i, params in enumerate(every_other):
    print(f"  Original index {2*i}: coupling={params['coupling']:.3f}")
Subset contains 3 combinations
Subset combinations:
  Original index 1: coupling=0.200
  Original index 2: coupling=0.400
  Original index 3: coupling=0.600

Every other combination (3 total):
  Original index 0: coupling=0.000
  Original index 2: coupling=0.400
  Original index 4: coupling=0.800
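To see which original combinations a slice selects, apply the same slice to range(len(space)). The helper below is plain Python and follows standard slice semantics. The coupling values printed above for [::2] (0.0, 0.4, 0.8) correspond to original indices 0, 2, and 4, which is consistent with this.

```python
n_combinations = 6  # len(demo_space)


def original_indices(n, s):
    """Indices of the original sequence selected by slice s (standard Python semantics)."""
    return list(range(n)[s])


print(original_indices(n_combinations, slice(1, 4)))           # [1, 2, 3]
print(original_indices(n_combinations, slice(None, None, 2)))  # [0, 2, 4]
print(original_indices(n_combinations, slice(-2, None)))       # [4, 5]
```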

Converting to DataFrames

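Both spaces and execution results can be converted to pandas DataFrames for easy analysis and plotting. Column names are derived automatically from the pytree structure of your state. As the output below shows, only axis-driven parameters become columns; the fixed threshold value does not appear.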

# Parameter space as a DataFrame
df_params = demo_space.to_dataframe()
print(df_params.head())
   coupling     noise
0       0.0  0.070052
1       0.2  0.074934
2       0.4  0.021412
3       0.6  0.041885
4       0.8  0.047898

After running a computation, Result.to_dataframe() combines parameters and results into a single DataFrame, keeping them correctly aligned:

from tvboptim.execution import SequentialExecution

def compute_metric(state):
    return {'activity': state['coupling'] * 2 + state['noise'],
            'ratio': state['coupling'] / (state['noise'] + 1e-6)}

result = SequentialExecution(compute_metric, demo_space).run()
df = result.to_dataframe()
print(df.head())

   coupling     noise  activity      ratio
0       0.0  0.070052  0.070052   0.000000
1       0.2  0.074934  0.474934   2.668987
2       0.4  0.021412  0.821412  18.680656
3       0.6  0.041885  1.241885  14.324623
4       0.8  0.047898  1.647898  16.701727
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8.1, 4))

ax1.scatter(df['coupling'], df['activity'])
ax1.set_xlabel('Coupling')
ax1.set_ylabel('Activity')
ax1.set_title('Coupling vs Activity')
ax1.grid(True, alpha=0.3)

ax2.scatter(df['noise'], df['ratio'], c=df['coupling'], cmap='viridis')
ax2.set_xlabel('Noise')
ax2.set_ylabel('Ratio')
ax2.set_title('Noise vs Ratio (colored by Coupling)')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

Nested state structures produce dot-separated column names (e.g. model.G, noise.sigma), and non-scalar results like arrays or matrices are stored as object cells in the DataFrame.
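The naming convention for nested states works like flattening the state dictionary and joining keys with dots. flatten_state is a hypothetical helper written only to illustrate this convention. It is not TVB-Optim's implementation and ignores non-dict pytree containers.

```python
def flatten_state(state, prefix=""):
    """Flatten nested dicts into {'a.b': value} using dot-separated keys."""
    flat = {}
    for key, value in state.items():
        name = f"{prefix}.{key}" if prefix else str(key)
        if isinstance(value, dict):
            flat.update(flatten_state(value, name))
        else:
            flat[name] = value
    return flat


params = {"model": {"G": 0.5}, "noise": {"sigma": 0.01}}  # hypothetical sampled state
print(flatten_state(params))  # {'model.G': 0.5, 'noise.sigma': 0.01}
```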

Execution: Running Functions Over Spaces

Now that we understand how to create and manipulate parameter spaces, let's see how to actually execute functions over them. TVB-Optim provides two execution strategies through the SequentialExecution and ParallelExecution classes.

Sequential Execution

Sequential execution processes parameter combinations one at a time, making it ideal for debugging, memory-constrained environments, or when you need full control over execution order.
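Sequential execution is essentially a loop that calls the model once per combination and forwards any extra keyword arguments, such as noise_factor below. The plain-Python sketch below shows that pattern, with a list of dicts standing in for a Space. run_sequential and toy_model are hypothetical helpers. The sketch omits the progress bar and the result object that SequentialExecution provides.

```python
def run_sequential(model, space, **model_kwargs):
    """Call model(params, **model_kwargs) for each combination, in order."""
    results = []
    for params in space:
        results.append(model(params, **model_kwargs))
    return results


def toy_model(state, scale=1.0):
    # Hypothetical stand-in for a simulation: returns a single scalar metric
    return {"metric": state["coupling"] * scale + state["noise"]}


toy_space = [{"coupling": 0.2, "noise": 0.01}, {"coupling": 0.6, "noise": 0.02}]
results = run_sequential(toy_model, toy_space, scale=2.0)
print([round(r["metric"], 3) for r in results])  # [0.41, 1.22]
```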

from tvboptim.execution import SequentialExecution

# Define a simple brain simulation model
def brain_simulation(state, noise_factor=1.0):
    """
    Simple brain simulation that combines coupling and noise.
    Returns both simulation results and metadata.
    """
    coupling = state['coupling']
    noise = state['noise'] * noise_factor

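    # Simulate some dynamics (simplified for demonstration).
    # The PRNG key is fixed, so every run reuses the same noise realization
    # and runs differ only in coupling and noise amplitude.
    t = jnp.linspace(0, 4*jnp.pi, 100)
    activity = coupling * jnp.sin(t) + noise * jax.random.normal(jax.random.key(42), shape=(100,))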

    # Calculate some metrics
    mean_activity = jnp.mean(activity)
    peak_activity = jnp.max(activity)

    return {
        'activity': activity,
        'metrics': {
            'mean': mean_activity,
            'peak': peak_activity,
            'coupling_used': coupling,
            'noise_used': noise
        }
    }

# Create a parameter space for our simulation
simulation_state = {
    'coupling': GridAxis(0.2, 1.5, 4),
    'noise': UniformAxis(0.01, 0.1, 4),
    'fixed_param': 42.0
}

space = Space(simulation_state, mode='zip', key=jax.random.key(123))
print(f"Parameter space: {len(space)} combinations")

# Set up sequential execution
sequential_executor = SequentialExecution(
    model=brain_simulation,
    statespace=space,
    noise_factor=1.5  # Additional parameter passed to model
)

print("\nExecuting simulations sequentially...")
Parameter space: 4 combinations

Executing simulations sequentially...
# Run the sequential execution (with progress bar)
sequential_results = sequential_executor.run()

print(f"Execution complete! Got {len(sequential_results)} results")

# Access results
first_result = sequential_results[0]
print(f"\nFirst result structure:")
print(f"  Activity shape: {first_result['activity'].shape}")
print(f"  Mean activity: {first_result['metrics']['mean']:.3f}")
print(f"  Peak activity: {first_result['metrics']['peak']:.3f}")
print(f"  Coupling used: {first_result['metrics']['coupling_used']:.3f}")

# Analyze all results
print(f"\nAll simulation metrics:")
for i, result in enumerate(sequential_results):
    metrics = result['metrics']
    print(f"  Run {i}: coupling={metrics['coupling_used']:.2f}, "
          f"noise={metrics['noise_used']:.4f}, mean={metrics['mean']:.3f}")

Execution complete! Got 4 results

First result structure:
  Activity shape: (100,)
  Mean activity: 0.001
  Peak activity: 0.257
  Coupling used: 0.200

All simulation metrics:
  Run 0: coupling=0.20, noise=0.0397, mean=0.001
  Run 1: coupling=0.63, noise=0.1314, mean=0.004
  Run 2: coupling=1.07, noise=0.0306, mean=0.001
  Run 3: coupling=1.50, noise=0.0587, mean=0.002
# Visualize sequential execution results
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8.1, 4))

# Plot activity traces
for i, result in enumerate(sequential_results):
    activity = result['activity']
    coupling = result['metrics']['coupling_used']
    ax1.plot(activity[:50], alpha=0.7, label=f'Coupling={coupling:.2f}')

ax1.set_xlabel('Time Steps')
ax1.set_ylabel('Activity')
ax1.set_title('Brain Activity Traces (First 50 Steps)')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Plot coupling vs mean activity relationship
couplings = [r['metrics']['coupling_used'] for r in sequential_results]
means = [r['metrics']['mean'] for r in sequential_results]
ax2.scatter(couplings, means, s=100, alpha=0.7, color='blue')
ax2.set_xlabel('Coupling Strength')
ax2.set_ylabel('Mean Activity')
ax2.set_title('Coupling vs Mean Activity')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

Parallel Execution

ParallelExecution combines JAX's vmap, which vectorizes the model over a batch of n_vmap parameter combinations, with pmap, which distributes batches across n_pmap devices. For large parameter explorations this can be much faster than evaluating combinations one at a time.
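Assuming each parallel step evaluates n_vmap combinations on each of n_pmap devices, a space needs ceil(len(space) / (n_vmap * n_pmap)) steps. The sketch below does that bookkeeping in plain Python, including the 64-combination, n_vmap=8, n_pmap=1 configuration used in this section. How ParallelExecution handles a final partial batch (for example, by padding) is not described here, so the sketch only reports the number of unused slots. plan_batches is a hypothetical helper.

```python
import math


def plan_batches(n_combinations, n_vmap, n_pmap):
    """Return (per_step, n_steps, unused_slots) for vmap-within-pmap batching."""
    per_step = n_vmap * n_pmap                          # combinations evaluated per parallel step
    n_steps = math.ceil(n_combinations / per_step)
    unused_slots = n_steps * per_step - n_combinations  # empty slots in the last step
    return per_step, n_steps, unused_slots


print(plan_batches(64, n_vmap=8, n_pmap=1))  # (8, 8, 0)
print(plan_batches(64, n_vmap=8, n_pmap=8))  # (64, 1, 0)
print(plan_batches(45, n_vmap=8, n_pmap=1))  # (8, 6, 3)
```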

from tvboptim.execution import ParallelExecution

# Create a larger parameter space for parallel execution
parallel_state = {
    'coupling': GridAxis(0.1, 2.0, 8),
    'noise': UniformAxis(0.005, 0.05, 8),
    'threshold': 0.5  # Fixed parameter
}

parallel_space = Space(parallel_state, mode='product', key=jax.random.key(456))
print(f"Larger parameter space: {len(parallel_space)} combinations")

# Define a JAX-optimized model for parallel execution
def fast_brain_model(state):
    """
    JAX-optimized brain model for parallel execution.
    Simpler than sequential version for faster computation.
    """
    coupling = state['coupling']
    noise = state['noise']

    # Simplified dynamics that are cheaper to compute
    base_activity = coupling * jnp.sin(jnp.arange(20) * 0.5)
    # Fixed key: every combination shares the same noise realization
    noisy_activity = base_activity + noise * jax.random.normal(jax.random.key(123), shape=(20,))

    return {
        'mean_activity': jnp.mean(noisy_activity),
        'max_activity': jnp.max(noisy_activity),
        'params': state
    }

# Set up parallel execution
n_devices = jax.device_count()
print(f"Available devices: {n_devices}")

parallel_executor = ParallelExecution(
    model=fast_brain_model,
    space=parallel_space,
    n_vmap=8,  # Vectorize over 8 parameter combinations
    n_pmap=1   # Use 1 device (can use more if available)
)

print("\nExecuting simulations in parallel...")
Larger parameter space: 64 combinations
Available devices: 8

Executing simulations in parallel...
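Because `fast_brain_model` creates `jax.random.key(123)` inside the model, all 64 combinations share one noise realization and differ only in noise amplitude. If independent noise per combination is wanted, one option is to pass a separate key per combination. The sketch below shows this in plain JAX; `noisy_trace` is a hypothetical helper, and how (or whether) ParallelExecution supplies keys to the model is not covered here.

```python
import jax
import jax.numpy as jnp

def noisy_trace(coupling, noise, key):
    """Hypothetical variant of fast_brain_model's dynamics that takes its PRNG key as input."""
    base = coupling * jnp.sin(jnp.arange(20) * 0.5)
    return base + noise * jax.random.normal(key, shape=(20,))

couplings = jnp.array([0.1, 0.5, 1.0, 2.0])
noises = jnp.full(4, 0.05)

# One independent key per combination, derived from a single root key
keys = jax.random.split(jax.random.key(0), 4)

# vmap maps over the key array too, so each row gets its own noise draw
traces = jax.vmap(noisy_trace)(couplings, noises, keys)
print(traces.shape)  # (4, 20)
```

Deriving all keys from one root key keeps the run reproducible while still giving each combination its own noise.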
# Run parallel execution
parallel_results = parallel_executor.run()

print(f"Parallel execution complete! Got {len(parallel_results)} results")

# Access results - same interface as sequential
first_parallel = parallel_results[0]
print(f"\nFirst parallel result:")
print(f"  Mean activity: {first_parallel['mean_activity']:.3f}")
print(f"  Max activity: {first_parallel['max_activity']:.3f}")
print(f"  Coupling: {first_parallel['params']['coupling']:.3f}")

# Get subset of results
subset_results = parallel_results[5:15]
print(f"\nSubset contains {len(subset_results)} results")

# Analyze parameter-activity relationships
couplings = [r['params']['coupling'] for r in parallel_results]
means = [r['mean_activity'] for r in parallel_results]
max_vals = [r['max_activity'] for r in parallel_results]

print(f"\nParameter exploration summary:")
print(f"  Coupling range: [{min(couplings):.2f}, {max(couplings):.2f}]")
print(f"  Mean activity range: [{min(means):.3f}, {max(means):.3f}]")
print(f"  Max activity range: [{min(max_vals):.3f}, {max(max_vals):.3f}]")
Parallel execution complete! Got 64 results

First parallel result:
  Mean activity: 0.012
  Max activity: 0.140
  Coupling: 0.100

Subset contains 10 results

Parameter exploration summary:
  Coupling range: [0.10, 2.00]
  Mean activity range: [0.011, 0.386]
  Max activity range: [0.095, 1.990]
# Visualize parallel execution results
fig, axes = plt.subplots(2, 2, figsize=(8.1, 8))

# 1. Coupling vs Mean Activity
ax = axes[0, 0]
ax.scatter(couplings, means, alpha=0.6, s=30)
ax.set_xlabel('Coupling Strength')
ax.set_ylabel('Mean Activity')
ax.set_title('Coupling vs Mean Activity')
ax.grid(True, alpha=0.3)

# 2. Noise vs Mean Activity
noises = [r['params']['noise'] for r in parallel_results]
ax = axes[0, 1]
ax.scatter(noises, means, alpha=0.6, s=30, color='orange')
ax.set_xlabel('Noise Level')
ax.set_ylabel('Mean Activity')
ax.set_title('Noise vs Mean Activity')
ax.grid(True, alpha=0.3)

# 3. 2D parameter space visualization
ax = axes[1, 0]
scatter = ax.scatter(couplings, noises, c=means, cmap='viridis', s=50)
ax.set_xlabel('Coupling Strength')
ax.set_ylabel('Noise Level')
ax.set_title('Parameter Space (colored by Mean Activity)')
plt.colorbar(scatter, ax=ax)

# 4. Mean vs Max Activity
ax = axes[1, 1]
ax.scatter(means, max_vals, alpha=0.6, s=30, color='red')
ax.set_xlabel('Mean Activity')
ax.set_ylabel('Max Activity')
ax.set_title('Mean vs Max Activity Relationship')
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
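Since product mode appears to enumerate the first axis slowest (in the preview at the top, coupling changes only every few combinations), a flat list of results can be reshaped into a coupling × noise grid, for example to draw a heatmap or to locate the best-performing combination. The sketch below uses synthetic stand-in results with the same dictionary layout as `parallel_results`; check the ordering assumption against the actual `params` entries before relying on it.

```python
import numpy as np

n_coupling, n_noise = 8, 8
coupling_vals = np.linspace(0.1, 2.0, n_coupling)
noise_vals = np.linspace(0.005, 0.05, n_noise)

# Synthetic stand-in for parallel_results: one dict per combination, coupling slowest
results = [
    {'mean_activity': 0.2 * c - n, 'params': {'coupling': c, 'noise': n}}
    for c in coupling_vals
    for n in noise_vals
]

# Flat list -> (n_coupling, n_noise) grid: row i is coupling_vals[i], column j is noise_vals[j]
mean_grid = np.array([r['mean_activity'] for r in results]).reshape(n_coupling, n_noise)

# Locate the combination with the highest mean activity and map it back to its parameters
i, j = np.unravel_index(np.argmax(mean_grid), mean_grid.shape)
best = results[i * n_noise + j]['params']
print(f"best: coupling={best['coupling']:.2f}, noise={best['noise']:.4f}")
# best: coupling=2.00, noise=0.0050
```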

Performance Comparison & Best Practices

import time

# Compare execution times for different approaches
small_state = {
    'param1': GridAxis(0.0, 1.0, 5),
    'param2': UniformAxis(0.0, 1.0, 5)
}

small_space = Space(small_state, mode='product', key=jax.random.key(789))

def simple_model(state):
    return {'result': state['param1'] * state['param2']}

print("Performance comparison on 25 parameter combinations:")
print("=" * 55)

# Sequential timing
seq_exec = SequentialExecution(simple_model, small_space)
start = time.time()
seq_results = seq_exec.run()
seq_time = time.time() - start
print(f"Sequential execution: {seq_time:.4f} seconds")

# Parallel timing
par_exec = ParallelExecution(simple_model, small_space, n_vmap=5, n_pmap=1)
start = time.time()
par_results = par_exec.run()
par_time = time.time() - start
print(f"Parallel execution:   {par_time:.4f} seconds")

if seq_time > par_time:
    speedup = seq_time / par_time
    print(f"Parallel speedup:     {speedup:.1f}x faster")
else:
    print("Sequential was faster (overhead dominates for small problems)")

print(f"\nBoth produced {len(seq_results)} results")
Performance comparison on 25 parameter combinations:
=======================================================

Sequential execution: 0.1857 seconds
Parallel execution:   0.0468 seconds
Parallel speedup:     4.0x faster

Both produced 25 results
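Wall-clock comparisons like this one are sensitive to two JAX details: the first call of a jitted function includes tracing and XLA compilation, and JAX dispatches work asynchronously, so a timer can stop before the computation has finished. A minimal sketch of separating compile time from steady-state time in plain JAX; `batched_product` is a hypothetical stand-in for `simple_model`, and whether the timings above include compilation depends on how ParallelExecution is implemented.

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def batched_product(p1, p2):
    # Element-wise product over all combinations at once
    return p1 * p2

p1 = jnp.linspace(0.0, 1.0, 25)
p2 = jnp.linspace(0.0, 1.0, 25)

# First call: tracing + compilation + execution; block_until_ready waits for the result
start = time.perf_counter()
batched_product(p1, p2).block_until_ready()
first_call = time.perf_counter() - start

# Second call with the same shapes reuses the cached compiled executable
start = time.perf_counter()
out = batched_product(p1, p2).block_until_ready()
second_call = time.perf_counter() - start

print(f"first call:  {first_call:.4f} s (includes compilation)")
print(f"second call: {second_call:.4f} s (cached)")
```

For fair comparisons, time repeated calls after a warm-up run rather than a single first call.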

When to Use Each Approach

Sequential Execution:

  • Debugging: Easy to trace through individual parameter combinations
  • Memory constraints: Processes one combination at a time
  • Complex models: When models don’t easily vectorize
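  • Small parameter spaces: Compilation and batching overhead may outweigh the gain (although in the 25-combination benchmark above, parallel execution was still faster)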

Parallel Execution:

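  • Large parameter spaces: The speedup grows with the number of combinations and pays off most for hundreds to thousands of them
  • JAX-compatible models: Models written with jax.numpy that vmap can batch, i.e. without Python control flow that depends on parameter values
  • Production runs: Once the model is debugged and throughput matters more than step-by-step inspection
  • Multi-device setups: pmap can spread batches across multiple GPUs/TPUs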
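For the memory trade-off between the two approaches, a back-of-envelope estimate helps when choosing n_vmap: a vectorized batch holds n_vmap simulation outputs in memory at once. The sketch below counts only the stored (n_steps, n_regions) output array in float32; intermediates, state histories, and JAX's own buffers add to this, so treat the result as an upper bound on the batch size. Both helper names are hypothetical.

```python
def bytes_per_simulation(n_steps, n_regions, itemsize=4):
    """Size of one (n_steps, n_regions) output array; itemsize=4 for float32."""
    return n_steps * n_regions * itemsize

def max_vmap_batch(memory_budget_bytes, n_steps, n_regions, itemsize=4):
    """Largest number of simulation outputs that fit in the budget (at least 1)."""
    per_sim = bytes_per_simulation(n_steps, n_regions, itemsize)
    return max(1, memory_budget_bytes // per_sim)

# Example: 10,000 time steps on a 68-region parcellation with a 1 GiB budget
per_sim = bytes_per_simulation(10_000, 68)
print(per_sim)                            # 2720000 bytes (about 2.7 MB)
print(max_vmap_batch(2**30, 10_000, 68))  # 394
```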

Integration with Spaces

Both execution types work seamlessly with all space features:

import numpyro.distributions as dist  # needed for the NumPyroAxis below

# Works with any axis type
complex_state = {
    'coupling': NumPyroAxis(dist.Beta(2.0, 5.0), n=3),  # Probabilistic
    'delay': DataAxis([5.0, 10.0, 15.0]),               # Specific values
    'noise': GridAxis(0.01, 0.1, 3),                    # Systematic
    'region_params': 68.0                               # Fixed value
}

complex_space = Space(complex_state, mode='product', key=jax.random.key(999))
print(f"Complex space: {len(complex_space)} combinations")

# Both executors handle this seamlessly
print("✓ Sequential execution: handles any space complexity")
print("✓ Parallel execution: vectorizes over any parameter combination")
print("✓ Results: same access patterns regardless of execution type")
Complex space: 27 combinations
✓ Sequential execution: handles any space complexity
✓ Parallel execution: vectorizes over any parameter combination
✓ Results: same access patterns regardless of execution type
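Conceptually, product mode takes the Cartesian product of the values each axis generates and attaches the fixed entries to every combination, which is why this space has 3 × 3 × 3 = 27 combinations. The sketch below mimics that expansion with plain Python lists standing in for axes. `expand_product` is a hypothetical helper, not TVB-Optim's implementation, and the coupling values are placeholders for the three Beta(2, 5) samples.

```python
import itertools

def expand_product(state):
    """Hypothetical product-mode expansion: lists/tuples act as axes, anything else is fixed."""
    axis_names = [k for k, v in state.items() if isinstance(v, (list, tuple))]
    fixed = {k: v for k, v in state.items() if k not in axis_names}
    # itertools.product varies the last axis fastest and the first axis slowest
    for combo in itertools.product(*(state[k] for k in axis_names)):
        yield {**fixed, **dict(zip(axis_names, combo))}

state = {
    'coupling': [0.1, 0.3, 0.5],      # placeholders for 3 Beta(2, 5) samples
    'delay': [5.0, 10.0, 15.0],       # like DataAxis([5.0, 10.0, 15.0])
    'noise': [0.01, 0.055, 0.1],      # like GridAxis(0.01, 0.1, 3)
    'region_params': 68.0,            # fixed value, copied into every combination
}

combos = list(expand_product(state))
print(len(combos))  # 27
print(combos[1])    # {'region_params': 68.0, 'coupling': 0.1, 'delay': 5.0, 'noise': 0.055}
```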

This execution system completes the TVB-Optim parameter exploration pipeline: it takes your parameter spaces and computes results across all combinations, using sequential execution when you need control and parallel execution when you need speed.