spaces.Space

spaces.Space(state, mode='zip', key=None)

Composable parameter space built from multiple axes.

Space discovers AbstractAxis instances in a parameter state tree and composes them into parameter combinations. It supports two combination modes: product (the Cartesian product of all axes) and zip (axes paired element-wise).

The Space class provides efficient parameter exploration by:

  • Automatic axis discovery: Finds AbstractAxis instances in state trees
  • Flexible combination modes: Product mode for full grid search, zip mode for parallel sampling
  • Efficient iteration: Pre-generates combinations for fast access
  • Slicing support: Create subspaces with space[start:end] syntax
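The two combination modes can be illustrated with plain JAX arrays. The sketch below is not tvboptim's implementation: a_values, b_values, product_combinations and zip_combinations are hypothetical names, and requiring equal axis lengths in zip mode is an assumption made only for this sketch.

```python
import jax.numpy as jnp

# Hypothetical axis values standing in for two axes (not the tvboptim API).
a_values = jnp.linspace(0.0, 1.0, 5)   # similar in spirit to GridAxis(0.0, 1.0, 5)
b_values = jnp.array([0.1, 0.5, 0.9])  # three fixed sample values

def product_combinations(x, y):
    """Cartesian product: every value of x paired with every value of y."""
    xx, yy = jnp.meshgrid(x, y, indexing="ij")
    return xx.ravel(), yy.ravel()

def zip_combinations(x, y):
    """Zip: pair the i-th value of x with the i-th value of y.

    Assumption of this sketch: both axes must have the same length.
    """
    if x.shape[0] != y.shape[0]:
        raise ValueError("zip mode needs axes of equal length in this sketch")
    return x, y

pa, pb = product_combinations(a_values, b_values)
print(pa.shape)  # (15,)  -> 5 × 3 combinations

za, zb = zip_combinations(a_values[:3], b_values)
print(za.shape)  # (3,)   -> one combination per aligned pair
```

Product mode grows multiplicatively with the number of axes, while zip mode stays at the length of a single axis, which is why zip suits paired or sampled explorations.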

Parameters

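| Name  | Type               | Description                                                                               | Default  |
|-------|--------------------|-------------------------------------------------------------------------------------------|----------|
| state | dict               | State tree containing AbstractAxis instances and fixed values.                            | required |
| mode  | str                | Combination mode: 'product' for the Cartesian product, 'zip' for element-wise pairing.     | 'zip'    |
| key   | jax.random.PRNGKey | Random key for stochastic axes. If None, a new key is created.                             | None     |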
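A stochastic axis needs a PRNG key to draw its values. The sketch below uses plain jax.random to show why an explicit key makes draws reproducible; sample_uniform_axis is a hypothetical stand-in, and whether UniformAxis samples exactly this way is an assumption.

```python
import jax

def sample_uniform_axis(key, low, high, n):
    """Hypothetical stand-in for a stochastic axis: n uniform draws in [low, high)."""
    return jax.random.uniform(key, shape=(n,), minval=low, maxval=high)

key = jax.random.PRNGKey(0)
s1 = sample_uniform_axis(key, 0.0, 1.0, 3)
s2 = sample_uniform_axis(key, 0.0, 1.0, 3)  # same key -> identical draws
s3 = sample_uniform_axis(jax.random.PRNGKey(1), 0.0, 1.0, 3)  # different key -> new draws
```

By the same reasoning, passing a fixed key such as key=jax.random.PRNGKey(0) to Space should make its stochastic axes repeatable across runs.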

Raises

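| Name | Type       | Description                                                                     |
|------|------------|---------------------------------------------------------------------------------|
|      | ValueError | If mode is not 'product' or 'zip', or if no AbstractAxis instances are found.   |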

Examples

>>> import jax.numpy as jnp
>>> import copy
>>> from tvboptim.types.spaces import Space, GridAxis, UniformAxis
>>>
>>> # Create exploration state (base_state is an existing state tree)
>>> exploration_state = copy.deepcopy(base_state)
>>> exploration_state.parameters.coupling.a = GridAxis(0.0, 1.0, 5)
>>> exploration_state.parameters.model.J_N = UniformAxis(0.0, 1.0, 3)
>>>
>>> # Product mode: 5 × 3 = 15 combinations
>>> space = Space(exploration_state, mode='product')
>>> print(f"Total combinations: {space.N}")
>>>
>>> # Access individual combinations
>>> first_state = space[0]
>>> print(first_state.parameters.coupling.a)
>>>
>>> # Create a subspace of combinations 2 through 7
>>> subset = space[2:8]
>>> print(f"Subset size: {len(subset)}")
>>>
>>> # Batch for parallel execution: 2 devices × 4 vmapped states each
>>> batched_states = space.collect(n_vmap=4, n_pmap=2)

Attributes

| Name | Description                               |
|------|-------------------------------------------|
| N    | Total number of parameter combinations.   |

Methods

| Name    | Description                                      |
|---------|--------------------------------------------------|
| collect | Generate batched states for parallel execution.  |

collect

spaces.Space.collect(n_vmap=None, n_pmap=None, fill_value=jnp.nan, combine=True)

Generate batched states for parallel execution.

Arranges the parameter combinations into arrays whose leading dimensions match nested jax.pmap (across devices) and jax.vmap (within each device) calls, which is how a large exploration is spread over accelerators.

Parameters

| Name       | Type  | Description                                                                                                                                   | Default |
|------------|-------|-----------------------------------------------------------------------------------------------------------------------------------------------|---------|
| n_vmap     | int   | Number of states to vectorize over using vmap. If None, defaults to 1.                                                                        | None    |
| n_pmap     | int   | Number of devices for parallel mapping with pmap. If None, defaults to 1.                                                                     | None    |
| fill_value | float | Value used for padding when reshaping to the target dimensions.                                                                               | jnp.nan |
| combine    | bool  | Whether to merge the static state with the parameter combinations into a complete state. If False, returns an (axis_tree, static_state) tuple. | True    |
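With combine=False, the batched axis tree and the static state come back separately. The sketch below shows one common way such a split and merge can work, borrowing the Equinox partition/combine convention of using None as a placeholder. Whether tvboptim marks the split this way is an assumption; partition, combine and is_axis are hypothetical helpers.

```python
import jax
import jax.numpy as jnp

# Hypothetical state: one batched axis (a JAX array) plus fixed values.
state = {
    "coupling": {"a": jnp.linspace(0.0, 1.0, 5)},
    "model": {"J_N": 0.5},
    "dt": 0.1,
}

def is_axis(leaf):
    # Assumption for this sketch: batched axis values are JAX arrays,
    # everything else is static.
    return isinstance(leaf, jax.Array)

def partition(tree):
    """Split into (axis_tree, static_tree), using None as the placeholder."""
    axis_tree = jax.tree_util.tree_map(lambda x: x if is_axis(x) else None, tree)
    static_tree = jax.tree_util.tree_map(lambda x: None if is_axis(x) else x, tree)
    return axis_tree, static_tree

def combine(axis_tree, static_tree):
    """Merge two complementary trees, taking whichever leaf is not None."""
    return jax.tree_util.tree_map(
        lambda a, s: s if a is None else a,
        axis_tree,
        static_tree,
        is_leaf=lambda x: x is None,  # treat None placeholders as leaves
    )

axis_tree, static_tree = partition(state)
restored = combine(axis_tree, static_tree)
```

Keeping the two parts apart lets only the batched arrays pass through vmap/pmap while plain Python values stay static.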

Returns

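| Name | Type          | Description                                                                                                                                                                                                                          |
|------|---------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|      | dict or tuple | If combine=True: state tree with batched parameter combinations shaped (n_pmap, n_vmap, n_map) for parallel execution, where n_map is the number of sequential batches needed to cover all N combinations. If combine=False: tuple of (batched_axis_tree, static_state). |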

Warnings

If n_pmap * n_vmap * n_map exceeds the number of combinations N, the surplus slots are padded with fill_value (NaN by default). Results computed from those padded slots should be discarded.
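The padding rule can be sketched for a single flat array of combination values. This assumes n_map = ceil(N / (n_pmap * n_vmap)) and that padding is appended after the last combination in row-major order; batch_for_parallel is a hypothetical helper, whereas collect itself works on whole state trees.

```python
import math
import jax.numpy as jnp

def batch_for_parallel(values, n_vmap=1, n_pmap=1, fill_value=jnp.nan):
    """Pad a flat float array of N combinations and reshape to (n_pmap, n_vmap, n_map)."""
    n = values.shape[0]
    per_step = n_pmap * n_vmap
    n_map = math.ceil(n / per_step)      # sequential steps needed to cover all N
    n_pad = per_step * n_map - n         # leftover slots after the last combination
    padding = jnp.full((n_pad,), fill_value, dtype=values.dtype)
    padded = jnp.concatenate([values, padding])
    return padded.reshape(n_pmap, n_vmap, n_map)

vals = jnp.arange(15, dtype=jnp.float32)  # e.g. the 15 product-mode combinations
batched = batch_for_parallel(vals, n_vmap=4, n_pmap=2)
print(batched.shape)                      # (2, 4, 2)
print(int(jnp.isnan(batched).sum()))      # 1
```

With 15 combinations and 2 × 4 = 8 slots per step, two steps (16 slots) are needed, so exactly one slot holds fill_value.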

Examples

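>>> import jax
>>>
>>> # Suppose exploration_state defines 100 combinations
>>> space = Space(exploration_state, mode='product')
>>>
>>> # Batch for parallel execution: 2 devices, 8 vmapped states each
>>> batched_states = space.collect(n_vmap=8, n_pmap=2)
>>> # batched_states is a state tree; each batched leaf has shape (2, 8, n_map)
>>> print(batched_states.parameters.coupling.a.shape)
>>>
>>> # Use with JAX transformations (simulation_fn is your own function;
>>> # pmap over 2 devices requires 2 available devices)
>>> results = jax.pmap(jax.vmap(simulation_fn))(batched_states)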
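The (n_pmap, n_vmap, n_map) layout lines up with nested transformations: pmap consumes the first axis, vmap the second, and the wrapped function sees the remaining n_map axis. The sketch below uses plain arrays rather than a state tree, a hypothetical simulation_fn that loops over n_map with jax.lax.map, and a NaN mask to drop padded slots. It sizes n_pmap to jax.local_device_count() so it also runs on a single CPU.

```python
import jax
import jax.numpy as jnp

def simulation_fn(a):
    # Hypothetical per-state computation. Under pmap(vmap(...)), `a` has
    # shape (n_map,), so lax.map processes the n_map entries sequentially.
    return jax.lax.map(lambda x: x ** 2, a)

n_pmap = jax.local_device_count()  # pmap needs one slice per local device
n_vmap, n_map = 4, 2
total = n_pmap * n_vmap * n_map
N = total - 1  # pretend the last slot is padding

flat_in = jnp.arange(total, dtype=jnp.float32).at[N:].set(jnp.nan)
batched = flat_in.reshape(n_pmap, n_vmap, n_map)

results = jax.pmap(jax.vmap(simulation_fn))(batched)  # shape (n_pmap, n_vmap, n_map)

valid = ~jnp.isnan(batched)  # True for real (non-padded) combinations
kept = results[valid]        # boolean indexing outside jit gives a 1-D array
```

Masking on the inputs rather than the outputs keeps the filter correct even when a real combination legitimately produces NaN.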