---
title: "Axes and Spaces: Systematic Parameter Exploration in TVB-Optim"
format:
  html:
    code-fold: false
    toc: true
    toc-depth: 3
    fig-width: 8
    out-width: "100%"
jupyter: python3
execute:
  cache: true
---
# Introduction & Overview

The Axes and Spaces system in TVB-Optim provides a composable framework for systematic parameter exploration in brain network simulations. It lets you:

- **Define parameter sampling strategies** using different axis types
- **Compose complex parameter spaces** from multiple axes
- **Leverage JAX transformations** for efficient parallel execution
- **Integrate** with optimization and analysis workflows

## Why This Design?

Traditional parameter exploration often means hand-writing nested loops and keeping track of which index maps to which parameter value. TVB-Optim's axis system replaces that bookkeeping: you declare how each parameter is sampled, and a `Space` generates the combinations for you.
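
For comparison, this is the kind of hand-written exploration that the axis system replaces. It is plain Python with hypothetical parameter values. `itertools.product` removes the nesting, but you still manage the value lists, the fixed parameters, and the assembly of each combination yourself.

```python
import itertools

# Hand-written exploration: one nested loop per varied parameter
coupling_values = [0.0, 0.5, 1.0, 1.5]  # hypothetical grid
noise_values = [0.01, 0.05, 0.1]        # hypothetical noise levels
fixed_param = 42.0

runs = []
for coupling in coupling_values:
    for noise in noise_values:
        runs.append({'coupling': coupling, 'noise': noise, 'fixed_param': fixed_param})

# The same combinations via itertools.product (the last list varies fastest)
runs_product = [
    {'coupling': c, 'noise': n, 'fixed_param': fixed_param}
    for c, n in itertools.product(coupling_values, noise_values)
]

print(len(runs), runs == runs_product)  # 12 True
```
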
Let's start with a quick preview of what we'll build:

```{python}
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from tvboptim.types.spaces import Space, GridAxis, UniformAxis, DataAxis, NumPyroAxis
# Quick preview: A mixed parameter space
preview_state = {
'coupling_strength': GridAxis(0.0, 2.0, 4),
'noise_level': UniformAxis(0.01, 0.1, 3),
'fixed_param': 42.0
}
space = Space(preview_state, mode='product', key=jax.random.key(123))
print(f"Created parameter space with {len(space)} combinations")
# Show first few parameter combinations
for i, params in enumerate(space):
    if i >= 5:
        break
    print(f"Combination {i}: coupling={params['coupling_strength']:.2f}, "
          f"noise={params['noise_level']:.4f}, fixed={params['fixed_param']}")
```
# Understanding Axes
## The AbstractAxis Interface
All axes in TVB-Optim implement a small, common interface:

- **`generate_values(key=None)`**: Produces sample values as JAX arrays
- **`size`**: Returns the number of samples this axis generates

```{python}
# Let's examine the interface
from tvboptim.types.spaces import GridAxis
grid = GridAxis(0.0, 1.0, 5)
print(f"Axis size: {grid.size}")
print(f"Generated values: {grid.generate_values()}")
print(f"Values shape: {grid.generate_values().shape}")
```
Each axis type can implement completely different sampling strategies while maintaining a consistent interface.
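
The following sketch shows that contract in concrete form. It is a duck-typed, hypothetical log-spaced grid (`LogGridAxis`, not part of TVB-Optim) that exposes the same two members. It does not subclass the library's `AbstractAxis`, whose exact requirements this tutorial does not cover. It also returns NumPy rather than JAX arrays so that it stays self-contained. Treat it as an illustration of the interface, not as a drop-in axis.

```python
import numpy as np


class LogGridAxis:
    """Hypothetical log-spaced grid axis (not part of TVB-Optim).

    Exposes the same two members as the built-in axes: a `size` property
    and `generate_values(key=None)`.
    """

    def __init__(self, low, high, n):
        if low <= 0 or high <= 0:
            raise ValueError("log spacing needs positive bounds")
        self.low, self.high, self.n = low, high, n

    @property
    def size(self):
        return self.n

    def generate_values(self, key=None):
        # Deterministic axis: the random key is accepted but ignored
        return np.geomspace(self.low, self.high, self.n)


axis = LogGridAxis(0.001, 1.0, 4)
print(axis.size)  # 4
print(axis.generate_values())
```
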
## GridAxis - Deterministic Sampling

GridAxis samples a parameter range at evenly spaced, deterministic points. This makes it a good fit for sensitivity analysis and systematic sweeps.
```{python}
# Basic grid sampling
grid_basic = GridAxis(0.0, 2.0, 11)
values_basic = grid_basic.generate_values()
print(f"Grid values: {values_basic}")
print(f"Spacing is uniform: {jnp.allclose(jnp.diff(values_basic), jnp.diff(values_basic)[0])}")
# Multiple grid densities
densities = [5, 10, 20]
for n in densities:
    grid = GridAxis(0.0, 1.0, n)
    values = grid.generate_values()
    print(f"n={n:2d}: {len(values)} values from {values[0]:.3f} to {values[-1]:.3f}")
```
```{python}
#| code-fold: true
#| code-summary: "Show visualization code"
# Visualize grid sampling
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8.1, 4))
# Plot 1: Basic grid values
ax1.scatter(range(len(values_basic)), values_basic, color='blue', s=50)
ax1.set_xlabel('Sample Index')
ax1.set_ylabel('Parameter Value')
ax1.set_title('GridAxis: Linear Spacing')
ax1.grid(True, alpha=0.3)
# Plot 2: Multiple grid densities
colors = ['red', 'green', 'blue']
for i, (n, color) in enumerate(zip(densities, colors)):
    grid = GridAxis(0.0, 1.0, n)
    values = grid.generate_values()
    ax2.scatter(values, [i] * len(values), color=color, s=30,
                label=f'n={n}', alpha=0.7)
ax2.set_xlabel('Parameter Value')
ax2.set_ylabel('Grid Density')
ax2.set_title('GridAxis: Different Densities')
ax2.legend()
ax2.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
```
### Shape Broadcasting with GridAxis
GridAxis also accepts a `shape` argument, and each grid value is broadcast to an array of that shape. This is useful for regional brain parameters that start out homogeneous and are later optimized per region:
```{python}
# Example: Regional coupling strengths (68 brain regions)
n_regions = 68
regional_grid = GridAxis(0.0, 1.0, 5, shape=(n_regions,))
regional_values = regional_grid.generate_values()
print(f"Regional grid shape: {regional_values.shape}")
print(f"Each sample has shape: {regional_values[0].shape}")
print(f"Sample 2 value: {regional_values[2][0]:.3f} (broadcasted to all regions)")
print(f"All regions identical per sample: {jnp.allclose(regional_values[2], regional_values[2][0])}")
```
```{python}
#| code-fold: true
#| code-summary: "Show visualization code"
# Visualize the broadcasting
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8.1, 4))
# Show that each sample broadcasts the same value
sample_idx = 2
ax1.plot(regional_values[sample_idx], 'o-', markersize=3)
ax1.set_xlabel('Brain Region')
ax1.set_ylabel('Coupling Strength')
ax1.set_title(f'Sample {sample_idx}: Broadcasted Value = {regional_values[sample_idx][0]:.2f}')
# Show progression across samples
ax2.plot(range(5), regional_values[:, 0], 'o-', linewidth=2, markersize=8)
ax2.set_xlabel('Sample Index')
ax2.set_ylabel('Parameter Value')
ax2.set_title('Parameter Progression Across Samples')
ax2.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
```
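
You can reproduce the same array layout with plain NumPy, which makes the convention explicit: the leading dimension indexes samples, and the remaining dimensions are the `shape` you passed. This sketch assumes that GridAxis spaces its points like `linspace` with both endpoints included, which is consistent with the grids printed earlier. It is not the library's implementation.

```python
import numpy as np

n_samples, n_regions = 5, 68

# One scalar per sample, evenly spaced like GridAxis(0.0, 1.0, 5) (assumed)
scalars = np.linspace(0.0, 1.0, n_samples)

# Broadcast each scalar to the per-sample shape (n_regions,)
regional = np.broadcast_to(scalars[:, None], (n_samples, n_regions))

print(regional.shape)   # (5, 68)
print(regional[2, :3])  # [0.5 0.5 0.5]
```

`np.broadcast_to` returns a read-only view. Call `.copy()` before you modify individual regions.
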
## UniformAxis - Random Sampling
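UniformAxis draws random samples from a uniform distribution. For a fixed evaluation budget, random sampling often explores a parameter range more efficiently than a grid, particularly when several parameters vary at once.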
```{python}
# Reproducible random sampling
key = jax.random.key(42)
uniform = UniformAxis(0.0, 1.0, 100)
values1 = uniform.generate_values(key)
values2 = uniform.generate_values(key) # Same key = same values
values3 = uniform.generate_values(jax.random.key(43)) # Different key
print(f"Same key gives identical results: {jnp.allclose(values1, values2)}")
print(f"Different key gives different results: {not jnp.allclose(values1, values3)}")
print(f"Mean value: {jnp.mean(values1):.3f} (should be ~0.5)")
print(f"Value range: [{jnp.min(values1):.3f}, {jnp.max(values1):.3f}]")
```
```{python}
#| code-fold: true
#| code-summary: "Show visualization code"
# Visualize distributions
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8.1, 4))
# Histogram of uniform samples
ax1.hist(values1, bins=20, alpha=0.7, color='green', density=True)
ax1.axhline(y=1.0, color='red', linestyle='--', label='Expected density = 1.0')
ax1.set_xlabel('Parameter Value')
ax1.set_ylabel('Density')
ax1.set_title('UniformAxis Distribution')
ax1.legend()
ax1.grid(True, alpha=0.3)
# Comparison of different sample sizes
sample_sizes = [10, 50, 200]
for size in sample_sizes:
    uniform_temp = UniformAxis(-2.0, 2.0, size)
    samples = uniform_temp.generate_values(jax.random.key(42))
    ax2.scatter(samples, [size] * len(samples), alpha=0.6, s=20,
                label=f'n={size}')
ax2.set_xlabel('Parameter Value')
ax2.set_ylabel('Sample Size')
ax2.set_title('UniformAxis: Different Sample Sizes')
ax2.legend()
ax2.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
```
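
One reason for that efficiency is how many distinct values each design tries. With the same number of evaluations, a grid tries only a few distinct values of each parameter, whereas random samples try a new value on every evaluation. If only one of the parameters really matters, the random design probes it much more finely. The NumPy sketch below uses a 16-point budget in two dimensions; the budget and seed are arbitrary.

```python
import itertools

import numpy as np

budget_per_dim = 4
n_dims = 2
n_points = budget_per_dim ** n_dims  # 16 evaluations in both designs

# Grid design: 4 levels per parameter, Cartesian product
levels = np.linspace(0.0, 1.0, budget_per_dim)
grid_points = np.array(list(itertools.product(levels, repeat=n_dims)))

# Random design with the same budget
rng = np.random.default_rng(0)
random_points = rng.uniform(0.0, 1.0, size=(n_points, n_dims))

# Distinct values tried along the first parameter
print(len(np.unique(grid_points[:, 0])))    # 4
print(len(np.unique(random_points[:, 0])))  # 16 (almost surely)
```
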
### Shape Broadcasting with UniformAxis
UniformAxis also supports shape broadcasting, useful for parameters that need the same random value across multiple dimensions:
```{python}
# Compare scalar vs broadcast sampling
key = jax.random.key(123)
# No shape: one scalar per sample
uniform_scalar = UniformAxis(0.0, 1.0, 3, shape=None)
scalar_values = uniform_scalar.generate_values(key)
# Broadcast sampling: one random value per sample, repeated across the shape
uniform_broadcast = UniformAxis(0.0, 1.0, 3, shape=(4, 2))
broadcast_values = uniform_broadcast.generate_values(key)
print("Scalar sampling (shape=None):")
print(f"Shape: {scalar_values.shape}")
print(f"Values: {scalar_values}")
print("\nBroadcast sampling:")
print(f"Shape: {broadcast_values.shape}")
print("First sample (4x2 matrix):")
print(broadcast_values[0])
print(f"All values in first sample identical: {jnp.allclose(broadcast_values[0], broadcast_values[0].flatten()[0])}")
```
## DataAxis - Predefined Values
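DataAxis wraps a predefined array of values. The first dimension indexes the samples, so each sample can itself be an array, such as a connectivity matrix.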
```{python}
# Example: Testing specific coupling values from literature
literature_values = jnp.array([0.142, 0.284, 0.426, 0.568, 0.710])
data_axis = DataAxis(literature_values)
print(f"DataAxis size: {data_axis.size}")
print(f"Values: {data_axis.generate_values()}")
# Works with multidimensional data too
connectivity_matrices = jnp.array([
[[1.0, 0.5], [0.5, 1.0]], # Weak coupling
[[1.0, 0.8], [0.8, 1.0]], # Medium coupling
[[1.0, 1.2], [1.2, 1.0]] # Strong coupling
])
matrix_axis = DataAxis(connectivity_matrices)
print(f"\nMatrix axis shape: {matrix_axis.generate_values().shape}")
print("First connectivity matrix:")
print(matrix_axis.generate_values()[0])
# Derived sequences can also serve as DataAxis values
fib_values = jnp.array([1, 1, 2, 3, 5, 8, 13, 21])
fib_normalized = fib_values / fib_values.max()
print(f"\nFibonacci sequence (normalized): {fib_normalized}")
# Oscillatory pattern
t = jnp.linspace(0, 4*jnp.pi, 20)
oscillatory = 0.5 + 0.3 * jnp.sin(t)
print(f"Oscillatory pattern range: [{jnp.min(oscillatory):.3f}, {jnp.max(oscillatory):.3f}]")
```
```{python}
#| code-fold: true
#| code-summary: "Show visualization code"
# Visualize different data patterns
fig, axes = plt.subplots(2, 2, figsize=(8.1, 8))
# 1. Literature values
ax = axes[0, 0]
ax.scatter(range(len(literature_values)), literature_values,
color='red', s=100, marker='s')
ax.set_xlabel('Index')
ax.set_ylabel('Coupling Strength')
ax.set_title('DataAxis: Literature Values')
ax.grid(True, alpha=0.3)
# 2. Fibonacci sequence
fib_axis = DataAxis(fib_normalized)
ax = axes[0, 1]
ax.plot(fib_axis.generate_values(), 'o-', color='purple', linewidth=2, markersize=8)
ax.set_xlabel('Index')
ax.set_ylabel('Normalized Value')
ax.set_title('DataAxis: Fibonacci Sequence')
ax.grid(True, alpha=0.3)
# 3. Oscillatory pattern
osc_axis = DataAxis(oscillatory)
ax = axes[1, 0]
ax.plot(osc_axis.generate_values(), 'o-', color='orange', linewidth=2, markersize=6)
ax.set_xlabel('Index')
ax.set_ylabel('Parameter Value')
ax.set_title('DataAxis: Oscillatory Pattern')
ax.grid(True, alpha=0.3)
# 4. Connectivity matrix visualization
ax = axes[1, 1]
im = ax.imshow(matrix_axis.generate_values()[1], cmap='viridis', vmin=0, vmax=1.2)
ax.set_title('DataAxis: Connectivity Matrix (Medium)')
plt.colorbar(im, ax=ax, fraction=0.046)
plt.tight_layout()
plt.show()
```
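
Instead of typing matrices by hand, you can build the stacked array programmatically. The NumPy sketch below reproduces the three `connectivity_matrices` above from a hypothetical base pattern and a list of coupling scales. The only property it relies on is that the first dimension indexes samples, as the shapes printed earlier show. You could then pass the result to `DataAxis(jnp.asarray(stack))`.

```python
import numpy as np

# Hypothetical off-diagonal pattern for 2 regions (no self-coupling)
base = np.array([[0.0, 1.0],
                 [1.0, 0.0]])
scales = np.array([0.5, 0.8, 1.2])  # weak, medium, strong coupling

# Scale the pattern per sample and put 1.0 on the diagonal;
# the leading dimension is the sample dimension
stack = scales[:, None, None] * base + np.eye(2)

print(stack.shape)  # (3, 2, 2)
print(stack[1])     # medium coupling matrix
```
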
## NumPyroAxis - Distribution Sampling
NumPyroAxis draws its samples from a NumPyro probability distribution. This makes it useful for uncertainty quantification and for sampling parameters from Bayesian priors.
```{python}
import numpyro.distributions as dist
# Different distribution types
key = jax.random.key(42)
# Normal distribution - common for many biological parameters
normal_axis = NumPyroAxis(dist.Normal(0.5, 0.15), n=1000)
normal_samples = normal_axis.generate_values(key)
# Beta distribution - great for bounded parameters [0,1]
beta_axis = NumPyroAxis(dist.Beta(2.0, 5.0), n=1000)
beta_samples = beta_axis.generate_values(key)
# LogNormal - for positive-only parameters like time constants
lognormal_axis = NumPyroAxis(dist.LogNormal(0.0, 0.5), n=1000)
lognormal_samples = lognormal_axis.generate_values(key)
print(f"Normal samples: mean={jnp.mean(normal_samples):.3f}, std={jnp.std(normal_samples):.3f}")
print(f"Beta samples: mean={jnp.mean(beta_samples):.3f}, range=[{jnp.min(beta_samples):.3f}, {jnp.max(beta_samples):.3f}]")
print(f"LogNormal samples: mean={jnp.mean(lognormal_samples):.3f}, median={jnp.median(lognormal_samples):.3f}")
```
```{python}
#| code-fold: true
#| code-summary: "Show visualization code"
# Visualize the distributions
fig, axes = plt.subplots(2, 2, figsize=(8.1, 8))
# Normal distribution
ax = axes[0, 0]
ax.hist(normal_samples, bins=50, density=True, alpha=0.7, color='blue')
ax.set_xlabel('Value')
ax.set_ylabel('Density')
ax.set_title('Normal(0.5, 0.15)')
ax.grid(True, alpha=0.3)
# Beta distribution
ax = axes[0, 1]
ax.hist(beta_samples, bins=50, density=True, alpha=0.7, color='green')
ax.set_xlabel('Value')
ax.set_ylabel('Density')
ax.set_title('Beta(2.0, 5.0)')
ax.grid(True, alpha=0.3)
# LogNormal distribution
ax = axes[1, 0]
ax.hist(lognormal_samples, bins=50, density=True, alpha=0.7, color='red')
ax.set_xlabel('Value')
ax.set_ylabel('Density')
ax.set_title('LogNormal(0.0, 0.5)')
ax.grid(True, alpha=0.3)
# Comparison of all three
ax = axes[1, 1]
ax.hist(normal_samples, bins=50, density=True, alpha=0.5, color='blue', label='Normal')
ax.hist(beta_samples, bins=50, density=True, alpha=0.5, color='green', label='Beta')
ax.hist(lognormal_samples[lognormal_samples < 5], bins=50, density=True, alpha=0.5, color='red', label='LogNormal')
ax.set_xlabel('Value')
ax.set_ylabel('Density')
ax.set_title('Distribution Comparison')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
```
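
You can check the summary statistics printed above against closed-form moments. For `LogNormal(loc, scale)`, the median is `exp(loc)` and the mean is `exp(loc + scale**2 / 2)`, which is why the mean exceeds the median for this right-skewed distribution. For `Beta(a, b)`, the mean is `a / (a + b)`. This sketch uses NumPy's generators, which parameterize these two distributions the same way as `numpyro.distributions`, rather than NumPyroAxis itself.

```python
import numpy as np

rng = np.random.default_rng(42)
n = 200_000

# Beta(a, b): mean = a / (a + b)
a, b = 2.0, 5.0
beta = rng.beta(a, b, size=n)
print(f"Beta mean: sample {beta.mean():.3f}, analytic {a / (a + b):.3f}")

# LogNormal(mu, sigma): median = exp(mu), mean = exp(mu + sigma**2 / 2)
mu, sigma = 0.0, 0.5
logn = rng.lognormal(mu, sigma, size=n)
print(f"LogNormal median: sample {np.median(logn):.3f}, analytic {np.exp(mu):.3f}")
print(f"LogNormal mean: sample {logn.mean():.3f}, analytic {np.exp(mu + sigma**2 / 2):.3f}")
```
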
### Independent vs Broadcast Sampling
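When you set a `sample_shape`, NumPyroAxis supports two sampling modes. With `broadcast_mode=False`, every element of a sample is drawn independently. With `broadcast_mode=True`, one draw per sample is repeated across the shape.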
```{python}
key = jax.random.key(456)
# Independent mode: each element sampled independently
independent_axis = NumPyroAxis(
dist.Normal(0.0, 1.0),
n=3,
sample_shape=(2, 4),
broadcast_mode=False
)
independent_samples = independent_axis.generate_values(key)
# Broadcast mode: one sample per axis point, broadcast to shape
broadcast_axis = NumPyroAxis(
dist.Normal(0.0, 1.0),
n=3,
sample_shape=(2, 4),
broadcast_mode=True
)
broadcast_samples = broadcast_axis.generate_values(key)
print("Independent sampling:")
print(f"Shape: {independent_samples.shape}")
print("Sample 0 (each element different):")
print(independent_samples[0])
print()
print("Broadcast sampling:")
print(f"Shape: {broadcast_samples.shape}")
print("Sample 0 (all elements identical):")
print(broadcast_samples[0])
print(f"All elements identical: {jnp.allclose(broadcast_samples[0], broadcast_samples[0].flatten()[0])}")
```
```{python}
#| code-fold: true
#| code-summary: "Show visualization code"
# Visualize the difference
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8.1, 4))
# Independent sampling heatmap
im1 = ax1.imshow(independent_samples[0], cmap='RdBu_r', vmin=-3, vmax=3)
ax1.set_title('Independent Sampling\n(Each element different)')
ax1.set_xlabel('Dimension 2')
ax1.set_ylabel('Dimension 1')
plt.colorbar(im1, ax=ax1, fraction=0.046)
# Broadcast sampling heatmap
im2 = ax2.imshow(broadcast_samples[0], cmap='RdBu_r', vmin=-3, vmax=3)
ax2.set_title('Broadcast Sampling\n(All elements identical)')
ax2.set_xlabel('Dimension 2')
ax2.set_ylabel('Dimension 1')
plt.colorbar(im2, ax=ax2, fraction=0.046)
plt.tight_layout()
plt.show()
```
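
The two modes differ only in how many random draws each sample uses. The NumPy sketch below builds the resulting arrays for a standard normal distribution, as in the example above. It mirrors the shapes printed earlier but is not the NumPyroAxis implementation, and the seed is arbitrary.

```python
import numpy as np

rng = np.random.default_rng(456)
n, sample_shape = 3, (2, 4)

# broadcast_mode=False analogue: an independent draw for every element
independent = rng.normal(0.0, 1.0, size=(n, *sample_shape))

# broadcast_mode=True analogue: one draw per sample, repeated over sample_shape
per_sample = rng.normal(0.0, 1.0, size=n)
broadcast = np.broadcast_to(
    per_sample.reshape((n,) + (1,) * len(sample_shape)), (n, *sample_shape)
)

print(independent.shape, broadcast.shape)
print(np.unique(broadcast[0]).size)  # 1 distinct value per broadcast sample
```
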
# Composing with Space
## Basic Space Creation
The `Space` class takes a state tree containing axes and fixed values and expands it into parameter combinations according to the chosen mode. Fixed values are copied unchanged into every combination.
```{python}
# Create a realistic brain simulation parameter space
simulation_state = {
'coupling': {
'strength': GridAxis(0.1, 0.8, 5), # Coupling strength
'delay': UniformAxis(5.0, 25.0, 3), # Transmission delays (ms)
},
'model': {
'noise_amplitude': DataAxis([0.01, 0.02, 0.05]), # Noise levels from pilot study
'time_constant': 10.0, # Fixed parameter
},
'simulation': {
'dt': 0.1, # Fixed time step
'duration': 1000, # Fixed duration
}
}
# Create the space
space = Space(simulation_state, mode='product', key=jax.random.key(789))
print(f"Parameter space contains {len(space)} combinations")
print(f"Combinations: 5 × 3 × 3 = {5*3*3} (Grid × Uniform × Data)")
# Look at a few parameter combinations
print("\nFirst 3 parameter combinations:")
print("=" * 50)
for i, params in enumerate(space):
    if i >= 3:
        break
    print(f"\nCombination {i}:")
    print(f" Coupling strength: {params['coupling']['strength']:.3f}")
    print(f" Coupling delay: {params['coupling']['delay']:.2f} ms")
    print(f" Noise amplitude: {params['model']['noise_amplitude']:.3f}")
    print(f" Time constant: {params['model']['time_constant']} (fixed)")
    print(f" Simulation dt: {params['simulation']['dt']} (fixed)")
```
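
Conceptually, product mode flattens the state tree, takes the Cartesian product over the axis leaves, and copies the fixed leaves into every combination. The pure-Python sketch below mimics this on the `simulation_state` above. Plain tuples stand in for axes, and the delay values are hypothetical stand-ins for the uniform draws. This is an illustration, not TVB-Optim's implementation. In particular, the order of combinations follows `itertools.product` and may differ from the order `Space` uses.

```python
import itertools


def flatten(tree, prefix=()):
    """Yield (path, leaf) pairs from a nested dict."""
    for name, value in tree.items():
        path = prefix + (name,)
        if isinstance(value, dict):
            yield from flatten(value, path)
        else:
            yield path, value


def set_path(tree, path, value):
    """Assign value at a nested path, creating intermediate dicts."""
    for name in path[:-1]:
        tree = tree.setdefault(name, {})
    tree[path[-1]] = value


def expand_product(state, is_axis=lambda leaf: isinstance(leaf, tuple)):
    """Cartesian product over axis leaves; fixed leaves are copied as-is."""
    leaves = list(flatten(state))
    axis_leaves = [(p, v) for p, v in leaves if is_axis(v)]
    fixed_leaves = [(p, v) for p, v in leaves if not is_axis(v)]
    for combo in itertools.product(*(values for _, values in axis_leaves)):
        out = {}
        for path, value in fixed_leaves:
            set_path(out, path, value)
        for (path, _), value in zip(axis_leaves, combo):
            set_path(out, path, value)
        yield out


# Tuples stand in for axes; delay values are hypothetical stand-ins for uniform draws
state = {
    'coupling': {'strength': (0.1, 0.275, 0.45, 0.625, 0.8),
                 'delay': (7.5, 14.0, 21.0)},
    'model': {'noise_amplitude': (0.01, 0.02, 0.05), 'time_constant': 10.0},
    'simulation': {'dt': 0.1, 'duration': 1000},
}

combos = list(expand_product(state))
print(len(combos))  # 45
print(combos[0])
```
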
## Combination Modes: Product vs Zip
The two combination modes serve different exploration strategies:

### Product Mode (Cartesian Product)

Use product mode for **systematic exploration**. It tests every combination of parameter values.
```{python}
# Product mode: systematic exploration
product_state = {
'param_a': GridAxis(0.0, 1.0, 3),
'param_b': GridAxis(0.0, 1.0, 2),
}
product_space = Space(product_state, mode='product')
print(f"Product mode: {len(product_space)} combinations")
print("\nAll combinations (product mode):")
for i, params in enumerate(product_space):
    print(f" {i}: a={params['param_a']:.1f}, b={params['param_b']:.1f}")
```
### Zip Mode (Parallel Sampling)
Use zip mode for **matched sampling**. It pairs the i-th value of every axis, so all axes advance together.
```{python}
# Zip mode: matched sampling
zip_space = Space(product_state, mode='zip')
print(f"Zip mode: {len(zip_space)} combinations (uses minimum axis size)")
print("\nMatched combinations (zip mode):")
for i, params in enumerate(zip_space):
    print(f" {i}: a={params['param_a']:.1f}, b={params['param_b']:.1f}")
```
```{python}
#| code-fold: true
#| code-summary: "Show visualization code"
# Create a more interesting comparison
comparison_state = {
'x': GridAxis(0.0, 2.0, 8),
'y': GridAxis(0.0, 2.0, 5),
}
product_comp = Space(comparison_state, mode='product')
zip_comp = Space(comparison_state, mode='zip')
# Extract coordinates for plotting
product_x = [p['x'] for p in product_comp]
product_y = [p['y'] for p in product_comp]
zip_x = [p['x'] for p in zip_comp]
zip_y = [p['y'] for p in zip_comp]
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8.1, 5))
# Product mode visualization
ax1.scatter(product_x, product_y, color='blue', s=50, alpha=0.7)
ax1.set_xlabel('Parameter X')
ax1.set_ylabel('Parameter Y')
ax1.set_title(f'Product Mode\n({len(product_comp)} combinations)')
ax1.grid(True, alpha=0.3)
ax1.set_xlim(-0.2, 2.2)
ax1.set_ylim(-0.2, 2.2)
# Zip mode visualization
ax2.scatter(zip_x, zip_y, color='red', s=100, alpha=0.7)
ax2.plot(zip_x, zip_y, 'r--', alpha=0.5, linewidth=1)
ax2.set_xlabel('Parameter X')
ax2.set_ylabel('Parameter Y')
ax2.set_title(f'Zip Mode\n({len(zip_comp)} combinations)')
ax2.grid(True, alpha=0.3)
ax2.set_xlim(-0.2, 2.2)
ax2.set_ylim(-0.2, 2.2)
plt.tight_layout()
plt.show()
```
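
The counts in the two panels follow from the two rules. Product mode yields the product of the axis sizes. Zip mode pairs values position by position and, as noted above, stops at the shortest axis, just like Python's built-in `zip`. Here is the calculation for the 8-point and 5-point axes:

```python
import math

sizes = {'x': 8, 'y': 5}

# Product mode: every pairing of values -> product of the axis sizes
n_product = math.prod(sizes.values())

# Zip mode: the i-th values are paired, so the shortest axis sets the length
n_zip = len(list(zip(range(sizes['x']), range(sizes['y']))))

print(n_product, n_zip)  # 40 5
```
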
## Grouped Axes: Mixing Combination Modes
In many real-world scenarios, you need **both** zip and product behavior in the same space. For example, you might have matched experimental settings that must move in lockstep, combined with a parameter grid you want to explore exhaustively.

The `group` parameter on axes solves this. Axes sharing the same `group` label are **zipped together** into a single composite axis, regardless of where they sit in the state tree. Then all groups and ungrouped axes combine using the Space's mode (typically product).

### Basic Example
```{python}
# Three matched settings that must stay together
setting_values_a = jnp.array([1.0, 2.0, 3.0])
setting_values_b = jnp.array([10.0, 20.0, 30.0])
grouped_state = {
'setting_a': DataAxis(setting_values_a, group='settings'),
'setting_b': DataAxis(setting_values_b, group='settings'),
'param_x': GridAxis(0.0, 1.0, 4),
}
space = Space(grouped_state, mode='product', key=jax.random.key(42))
print(f"Total combinations: {len(space)}")
print(f" = 3 settings x 4 param_x = {3 * 4}")
print(f" (NOT 3 x 3 x 4 = {3 * 3 * 4})")
print("\nAll combinations:")
for i, params in enumerate(space):
    print(f" {i}: setting_a={params['setting_a']:.0f}, "
          f"setting_b={params['setting_b']:.0f}, "
          f"param_x={params['param_x']:.2f}")
```
Notice that `setting_a` and `setting_b` always move together (1 with 10, 2 with 20, 3 with 30) while `param_x` varies independently across its full grid.

Groups also work across subtrees, so grouped axes don't need to be co-located in the state tree. You can use several independent groups, each with its own label. Each group is zipped internally, and the groups are then combined via the Space's mode. The `group` label can be any hashable value (strings, ints, tuples, etc.).
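
The grouping rule comes down to two steps. First, zip the axes within each group into one composite axis. Then combine the composite and ungrouped axes with the Space's mode. The pure-Python sketch below applies that logic to the example above, assuming that `GridAxis(0.0, 1.0, 4)` includes both endpoints. The order of combinations follows `itertools.product` and is not guaranteed to match the order `Space` uses.

```python
import itertools

setting_a = [1.0, 2.0, 3.0]
setting_b = [10.0, 20.0, 30.0]
param_x = [0.0, 1 / 3, 2 / 3, 1.0]  # assumed points of GridAxis(0.0, 1.0, 4)

# Step 1: zip the members of the 'settings' group into one composite axis
settings = list(zip(setting_a, setting_b))

# Step 2: combine the composite axis with the ungrouped axis via product
combos = [
    {'setting_a': a, 'setting_b': b, 'param_x': x}
    for (a, b), x in itertools.product(settings, param_x)
]

print(len(combos))  # 12 = 3 settings x 4 param_x
```
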
## Iteration & Access Patterns
Space provides flexible ways to access parameter combinations:
```{python}
# Create a space for demonstration
demo_state = {
'coupling': GridAxis(0.0, 1.0, 6),
'noise': UniformAxis(0.01, 0.1, 6),
'threshold': 0.5 # Fixed value
}
demo_space = Space(demo_state, mode='zip', key=jax.random.key(42))
print(f"Demo space has {len(demo_space)} combinations")
```
### Sequential Iteration
```{python}
# Standard iteration - perfect for simulations
print("Sequential iteration (first 3):")
for i, params in enumerate(demo_space):
    if i >= 3:
        break
    print(f" Run {i}: coupling={params['coupling']:.2f}, "
          f"noise={params['noise']:.4f}")
```
### Random Access
```{python}
# Direct indexing - great for debugging specific combinations
print("Random access examples:")
print(f" Combination 0: coupling = {demo_space[0]['coupling']:.3f}")
print(f" Combination 2: coupling = {demo_space[2]['coupling']:.3f}")
print(f" Last combination: coupling = {demo_space[-1]['coupling']:.3f}")
# Negative indexing works too
print(f" Second to last: coupling = {demo_space[-2]['coupling']:.3f}")
```
### Slicing for Subset Analysis
```{python}
# Slicing creates new spaces - perfect for subset analysis
subset = demo_space[1:4]
print(f"Subset contains {len(subset)} combinations")
print("Subset combinations:")
for i, params in enumerate(subset):
    print(f" Original index {i+1}: coupling={params['coupling']:.3f}")
# Slicing with steps
every_other = demo_space[::2]
print(f"\nEvery other combination ({len(every_other)} total):")
for i, params in enumerate(every_other):
    print(f" Original index {2*i}: coupling={params['coupling']:.3f}")
```
### Converting to DataFrames
Both spaces and execution results can be converted to pandas DataFrames for easy analysis and plotting. Column names are derived automatically from the pytree structure of your state.
```{python}
# Parameter space as a DataFrame
df_params = demo_space.to_dataframe()
print(df_params.head())
```
After running a computation, `Result.to_dataframe()` combines parameters and results into a single DataFrame, keeping them correctly aligned:
```{python}
from tvboptim.execution import SequentialExecution
def compute_metric(state):
    return {'activity': state['coupling'] * 2 + state['noise'],
            'ratio': state['coupling'] / (state['noise'] + 1e-6)}
result = SequentialExecution(compute_metric, demo_space).run()
df = result.to_dataframe()
print(df.head())
```
```{python}
#| code-fold: true
#| code-summary: "Show plotting example"
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8.1, 4))
ax1.scatter(df['coupling'], df['activity'])
ax1.set_xlabel('Coupling')
ax1.set_ylabel('Activity')
ax1.set_title('Coupling vs Activity')
ax1.grid(True, alpha=0.3)
ax2.scatter(df['noise'], df['ratio'], c=df['coupling'], cmap='viridis')
ax2.set_xlabel('Noise')
ax2.set_ylabel('Ratio')
ax2.set_title('Noise vs Ratio (colored by Coupling)')
ax2.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
```
Nested state structures produce dot-separated column names (e.g. `model.G`, `noise.sigma`), and non-scalar results like arrays or matrices are stored as object cells in the DataFrame.
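
The sketch below shows how a nested state maps onto dot-separated column names, using the `model.G` / `noise.sigma` example. The helper `flatten_columns` is hypothetical and only illustrates the naming convention; it is not the code behind `to_dataframe()`.

```python
def flatten_columns(tree, prefix=''):
    """Flatten a nested dict into {'a.b': value} pairs (illustrative only)."""
    flat = {}
    for name, value in tree.items():
        key = f"{prefix}.{name}" if prefix else str(name)
        if isinstance(value, dict):
            flat.update(flatten_columns(value, key))
        else:
            flat[key] = value
    return flat


row = flatten_columns({'model': {'G': 0.5}, 'noise': {'sigma': 0.01}, 'dt': 0.1})
print(row)  # {'model.G': 0.5, 'noise.sigma': 0.01, 'dt': 0.1}
```

You can pass a list of such flattened rows straight to `pandas.DataFrame` to get one column per parameter.
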
# Execution: Running Functions Over Spaces
Now that we understand how to create and manipulate parameter spaces, let's see how to actually execute functions over them. TVB-Optim provides two execution strategies through `SequentialExecution` and `ParallelExecution` classes.
## Sequential Execution
Sequential execution processes parameter combinations one at a time, making it ideal for debugging, memory-constrained environments, or when you need full control over execution order.
```{python}
from tvboptim.execution import SequentialExecution
# Define a simple brain simulation model
def brain_simulation(state, noise_factor=1.0):
    """
    Simple brain simulation that combines coupling and noise.
    Returns both simulation results and metadata.
    """
    coupling = state['coupling']
    noise = state['noise'] * noise_factor
    # Simulate some dynamics (simplified for demonstration).
    # The fixed PRNG key means every run sees the same noise realization,
    # scaled by its own noise level.
    activity = coupling * jnp.sin(jnp.linspace(0, 4*jnp.pi, 100)) + noise * jax.random.normal(jax.random.key(42), shape=(100,))
    # Calculate some metrics
    mean_activity = jnp.mean(activity)
    peak_activity = jnp.max(activity)
    return {
        'activity': activity,
        'metrics': {
            'mean': mean_activity,
            'peak': peak_activity,
            'coupling_used': coupling,
            'noise_used': noise
        }
    }
# Create a parameter space for our simulation
simulation_state = {
'coupling': GridAxis(0.2, 1.5, 4),
'noise': UniformAxis(0.01, 0.1, 4),
'fixed_param': 42.0
}
space = Space(simulation_state, mode='zip', key=jax.random.key(123))
print(f"Parameter space: {len(space)} combinations")
# Set up sequential execution
sequential_executor = SequentialExecution(
model=brain_simulation,
statespace=space,
noise_factor=1.5 # Additional parameter passed to model
)
print("\nExecuting simulations sequentially...")
```
```{python}
# Run the sequential execution (with progress bar)
sequential_results = sequential_executor.run()
print(f"Execution complete! Got {len(sequential_results)} results")
# Access results
first_result = sequential_results[0]
print(f"\nFirst result structure:")
print(f" Activity shape: {first_result['activity'].shape}")
print(f" Mean activity: {first_result['metrics']['mean']:.3f}")
print(f" Peak activity: {first_result['metrics']['peak']:.3f}")
print(f" Coupling used: {first_result['metrics']['coupling_used']:.3f}")
# Analyze all results
print(f"\nAll simulation metrics:")
for i, result in enumerate(sequential_results):
    metrics = result['metrics']
    print(f" Run {i}: coupling={metrics['coupling_used']:.2f}, "
          f"noise={metrics['noise_used']:.4f}, mean={metrics['mean']:.3f}")
```
```{python}
#| code-fold: true
#| code-summary: "Show visualization code"
# Visualize sequential execution results
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8.1, 4))
# Plot activity traces
for i, result in enumerate(sequential_results):
    activity = result['activity']
    coupling = result['metrics']['coupling_used']
    ax1.plot(activity[:50], alpha=0.7, label=f'Coupling={coupling:.2f}')
ax1.set_xlabel('Time Steps')
ax1.set_ylabel('Activity')
ax1.set_title('Brain Activity Traces (First 50 Steps)')
ax1.legend()
ax1.grid(True, alpha=0.3)
# Plot coupling vs mean activity relationship
couplings = [r['metrics']['coupling_used'] for r in sequential_results]
means = [r['metrics']['mean'] for r in sequential_results]
ax2.scatter(couplings, means, s=100, alpha=0.7, color='blue')
ax2.set_xlabel('Coupling Strength')
ax2.set_ylabel('Mean Activity')
ax2.set_title('Coupling vs Mean Activity')
ax2.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
```
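
Conceptually, sequential execution is a loop that calls the model once per combination. Any extra keyword arguments, such as `noise_factor=1.5` above, are forwarded to every call. The minimal sketch below uses a hypothetical helper, `run_sequential`, as a stand-in for `SequentialExecution`. The real class also shows a progress bar and returns a results object that supports `len()` and indexing.

```python
def run_sequential(model, combinations, **model_kwargs):
    """Hypothetical stand-in for SequentialExecution: call the model once
    per combination, forwarding extra keyword arguments to every call."""
    return [model(params, **model_kwargs) for params in combinations]


def toy_model(state, noise_factor=1.0):
    # Mirrors the signature of brain_simulation: state first, extras as kwargs
    return {'scaled_noise': state['noise'] * noise_factor}


combos = [{'noise': 0.01}, {'noise': 0.02}]
results = run_sequential(toy_model, combos, noise_factor=1.5)
print([round(r['scaled_noise'], 4) for r in results])  # [0.015, 0.03]
```
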
## Parallel Execution
Parallel execution uses two JAX transformations. `vmap` vectorizes the model over a batch of parameter combinations on one device, and `pmap` spreads batches across multiple devices. Together, they make large-scale parameter explorations practical.
```{python}
from tvboptim.execution import ParallelExecution
# Create a larger parameter space for parallel execution
parallel_state = {
'coupling': GridAxis(0.1, 2.0, 8),
'noise': UniformAxis(0.005, 0.05, 8),
'threshold': 0.5 # Fixed parameter
}
parallel_space = Space(parallel_state, mode='product', key=jax.random.key(456))
print(f"Larger parameter space: {len(parallel_space)} combinations")
# Define a JAX-optimized model for parallel execution
def fast_brain_model(state):
    """
    JAX-optimized brain model for parallel execution.
    Simpler than the sequential version for faster computation.
    """
    coupling = state['coupling']
    noise = state['noise']
    # Simplified dynamics that are cheap to compute
    base_activity = coupling * jnp.sin(jnp.arange(20) * 0.5)
    noisy_activity = base_activity + noise * jax.random.normal(jax.random.key(123), shape=(20,))
    return {
        'mean_activity': jnp.mean(noisy_activity),
        'max_activity': jnp.max(noisy_activity),
        'params': state
    }
# Set up parallel execution
n_devices = jax.device_count()
print(f"Available devices: {n_devices}")
parallel_executor = ParallelExecution(
model=fast_brain_model,
space=parallel_space,
n_vmap=8, # Vectorize over 8 parameter combinations
n_pmap=1 # Use 1 device (can use more if available)
)
print("\nExecuting simulations in parallel...")
```
```{python}
# Run parallel execution
parallel_results = parallel_executor.run()
print(f"Parallel execution complete! Got {len(parallel_results)} results")
# Access results - same interface as sequential
first_parallel = parallel_results[0]
print(f"\nFirst parallel result:")
print(f" Mean activity: {first_parallel['mean_activity']:.3f}")
print(f" Max activity: {first_parallel['max_activity']:.3f}")
print(f" Coupling: {first_parallel['params']['coupling']:.3f}")
# Get subset of results
subset_results = parallel_results[5:15]
print(f"\nSubset contains {len(subset_results)} results")
# Analyze parameter-activity relationships
couplings = [r['params']['coupling'] for r in parallel_results]
means = [r['mean_activity'] for r in parallel_results]
max_vals = [r['max_activity'] for r in parallel_results]
print(f"\nParameter exploration summary:")
print(f" Coupling range: [{min(couplings):.2f}, {max(couplings):.2f}]")
print(f" Mean activity range: [{min(means):.3f}, {max(means):.3f}]")
print(f" Max activity range: [{min(max_vals):.3f}, {max(max_vals):.3f}]")
```
```{python}
#| code-fold: true
#| code-summary: "Show visualization code"
# Visualize parallel execution results
fig, axes = plt.subplots(2, 2, figsize=(8.1, 8))
# 1. Coupling vs Mean Activity
ax = axes[0, 0]
ax.scatter(couplings, means, alpha=0.6, s=30)
ax.set_xlabel('Coupling Strength')
ax.set_ylabel('Mean Activity')
ax.set_title('Coupling vs Mean Activity')
ax.grid(True, alpha=0.3)
# 2. Noise vs Mean Activity
noises = [r['params']['noise'] for r in parallel_results]
ax = axes[0, 1]
ax.scatter(noises, means, alpha=0.6, s=30, color='orange')
ax.set_xlabel('Noise Level')
ax.set_ylabel('Mean Activity')
ax.set_title('Noise vs Mean Activity')
ax.grid(True, alpha=0.3)
# 3. 2D parameter space visualization
ax = axes[1, 0]
scatter = ax.scatter(couplings, noises, c=means, cmap='viridis', s=50)
ax.set_xlabel('Coupling Strength')
ax.set_ylabel('Noise Level')
ax.set_title('Parameter Space (colored by Mean Activity)')
plt.colorbar(scatter, ax=ax)
# 4. Mean vs Max Activity
ax = axes[1, 1]
ax.scatter(means, max_vals, alpha=0.6, s=30, color='red')
ax.set_xlabel('Mean Activity')
ax.set_ylabel('Max Activity')
ax.set_title('Mean vs Max Activity Relationship')
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
```
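
One way to think about `n_vmap` and `n_pmap` is as a reshape of the combination indices into batches of `n_pmap × n_vmap`. Under that view, `pmap` maps over the middle dimension (devices), `vmap` maps over the last one (combinations per device), and the batches run one after another. The NumPy sketch below only computes such an index layout. The batch order and the padding strategy for sizes that do not divide evenly are assumptions, not documented behavior of `ParallelExecution`.

```python
import numpy as np


def batch_layout(n_combinations, n_vmap, n_pmap):
    """Hypothetical sketch: arrange combination indices as
    (n_batches, n_pmap, n_vmap), padding the last batch by repeating
    the final valid index. Not ParallelExecution's actual code."""
    per_batch = n_vmap * n_pmap
    n_batches = -(-n_combinations // per_batch)  # ceiling division
    idx = np.arange(n_batches * per_batch)
    idx = np.minimum(idx, n_combinations - 1)    # pad with the last index
    return idx.reshape(n_batches, n_pmap, n_vmap)


# The 8 x 8 product space above with n_vmap=8, n_pmap=1
layout = batch_layout(64, n_vmap=8, n_pmap=1)
print(layout.shape)  # (8, 1, 8)

# A size that does not divide evenly: the last batch is padded
print(batch_layout(10, n_vmap=4, n_pmap=1)[-1])  # [[8 9 9 9]]
```
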
## Performance Comparison & Best Practices
```{python}
import time
# Compare execution times for different approaches
small_state = {
'param1': GridAxis(0.0, 1.0, 5),
'param2': UniformAxis(0.0, 1.0, 5)
}
small_space = Space(small_state, mode='product', key=jax.random.key(789))
def simple_model(state):
    return {'result': state['param1'] * state['param2']}
print("Performance comparison on 25 parameter combinations:")
print("=" * 55)
# Sequential timing (a single run, so expect some noise)
seq_exec = SequentialExecution(simple_model, small_space)
start = time.perf_counter()
seq_results = seq_exec.run()
seq_time = time.perf_counter() - start
print(f"Sequential execution: {seq_time:.4f} seconds")
# Parallel timing (the first run includes JIT compilation)
par_exec = ParallelExecution(simple_model, small_space, n_vmap=5, n_pmap=1)
start = time.perf_counter()
par_results = par_exec.run()
par_time = time.perf_counter() - start
print(f"Parallel execution: {par_time:.4f} seconds")
if seq_time > par_time:
    speedup = seq_time / par_time
    print(f"Parallel speedup: {speedup:.1f}x faster")
else:
    print("Sequential was faster (overhead dominates for small problems)")
print(f"\nBoth produced {len(seq_results)} results")
```
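
The timings above include one-off costs, and a single measurement is noisy. A fairer comparison warms each approach up once and then reports the median of several runs. The generic sketch below uses only the standard library. For JAX code, you would also call `jax.block_until_ready` on the output before stopping the clock, because JAX dispatches work asynchronously. Whether a second `run()` of an executor reuses its compiled function depends on how the executor caches it, so treat such numbers as indicative.

```python
import statistics
import time


def benchmark(fn, n_repeats=5):
    """Time fn() after one warm-up call (which absorbs one-off costs such as
    JIT compilation) and return the median wall-clock time in seconds."""
    fn()  # warm-up call, not timed
    times = []
    for _ in range(n_repeats):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)
    return statistics.median(times)


# Example with a cheap pure-Python workload
elapsed = benchmark(lambda: sum(i * i for i in range(10_000)))
print(f"median: {elapsed:.6f} s")
```
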
### When to Use Each Approach

**Sequential Execution:**

- **Debugging**: Easy to trace through individual parameter combinations
- **Memory constraints**: Processes one combination at a time
- **Complex models**: When models don't easily vectorize
- **Small parameter spaces**: Parallelization overhead is not worth it

**Parallel Execution:**

- **Large parameter spaces**: Large speedups for hundreds or thousands of combinations, once the compilation cost is amortized
- **JAX-compatible models**: Models that can be vectorized with `vmap`
- **Production runs**: When you need results fast
- **Multi-device setups**: Can use multiple GPUs/TPUs via `pmap`

### Integration with Spaces

Both execution types work with all space features:
```{python}
# Works with any axis type
complex_state = {
'coupling': NumPyroAxis(dist.Beta(2.0, 5.0), n=3), # Probabilistic
'delay': DataAxis([5.0, 10.0, 15.0]), # Specific values
'noise': GridAxis(0.01, 0.1, 3), # Systematic
'region_params': 68.0 # Fixed value
}
complex_space = Space(complex_state, mode='product', key=jax.random.key(999))
print(f"Complex space: {len(complex_space)} combinations")
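# Run both executors on a trivial model and compare the results
def combined_model(state):
    return {'score': state['coupling'] * state['noise'] + state['delay'] / state['region_params']}
seq_out = SequentialExecution(combined_model, complex_space).run()
par_out = ParallelExecution(combined_model, complex_space, n_vmap=9, n_pmap=1).run()
agree = all(jnp.allclose(seq_out[i]['score'], par_out[i]['score']) for i in range(len(complex_space)))
print(f"Sequential and parallel results agree: {agree}")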
```
This execution system completes the TVB-Optim parameter exploration pipeline. It takes your parameter spaces and computes results across all combinations, either sequentially when you need control and easy debugging, or in parallel when you need speed.