spaces.Space
spaces.Space(state, mode='zip', key=None)
Composable parameter space built from multiple axes.
Space discovers AbstractAxis instances in a parameter state tree and combines them into parameter combinations. It supports two combination modes: product (Cartesian product) and zip (parallel, element-wise pairing).
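The difference between the two modes can be illustrated with a plain-Python sketch built on itertools. The axis names and values are made up. The equal-length requirement in zip mode is an assumption of this sketch, because this page does not say how Space handles axes of different sizes in zip mode.

```python
import itertools

# Hypothetical axis values, standing in for what axes such as GridAxis would produce
axes = {
    "coupling.a": [0.0, 0.25, 0.5, 0.75, 1.0],  # 5 values
    "model.J_N": [0.1, 0.5, 0.9],               # 3 values
}


def product_combinations(axes):
    """Cartesian product: every value of each axis paired with every value of the others."""
    names = list(axes)
    return [dict(zip(names, values)) for values in itertools.product(*axes.values())]


def zip_combinations(axes):
    """Parallel (zip): the i-th combination takes the i-th value of every axis.

    Assumption of this sketch: all axes must have the same length.
    """
    lengths = {len(v) for v in axes.values()}
    if len(lengths) != 1:
        raise ValueError(f"zip mode needs equal-length axes, got lengths {sorted(lengths)}")
    names = list(axes)
    return [dict(zip(names, values)) for values in zip(*axes.values())]


prod = product_combinations(axes)
print(len(prod))  # 15
print(prod[0])    # {'coupling.a': 0.0, 'model.J_N': 0.1}

zipped = zip_combinations({"a": [0.0, 0.5, 1.0], "b": [10, 20, 30]})
print(zipped[1])  # {'a': 0.5, 'b': 20}
```

Product mode grows multiplicatively with the number of axes (5 × 3 = 15 here), while zip mode yields one combination per position along the axes.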
The Space class provides efficient parameter exploration by:
- Automatic axis discovery: Finds AbstractAxis instances in state trees
- Flexible combination modes: Product mode for full grid search, zip mode for parallel sampling
- Efficient iteration: Pre-generates combinations for fast access
- Slicing support: Create subspaces with `space[start:end]` syntax
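The discovery, pre-generation, and slicing steps listed above can be sketched with plain nested dicts. Axis, find_axes, and set_path are hypothetical stand-ins written for illustration: tvboptim's own state trees and AbstractAxis classes are not reproduced here. The sketch uses zip mode for brevity.

```python
from dataclasses import dataclass


@dataclass
class Axis:
    """Hypothetical stand-in for an AbstractAxis: it only holds its sample values."""
    values: list


def find_axes(tree, path=()):
    """Walk a nested dict and return {path: Axis} for every Axis leaf (discovery)."""
    found = {}
    for key, value in tree.items():
        if isinstance(value, Axis):
            found[path + (key,)] = value
        elif isinstance(value, dict):
            found.update(find_axes(value, path + (key,)))
    return found


def set_path(tree, path, value):
    """Return a copy of tree with the leaf at path replaced; the input is not mutated."""
    head, *rest = path
    new = dict(tree)
    new[head] = set_path(tree[head], rest, value) if rest else value
    return new


state = {
    "parameters": {
        "coupling": {"a": Axis([0.0, 0.5, 1.0])},
        "model": {"J_N": Axis([0.1, 0.2, 0.3]), "tau": 100.0},  # tau is a fixed value
    }
}

axes = find_axes(state)

# Pre-generate every concrete state once (zip mode), so indexing is a list lookup
combos = []
for values in zip(*(ax.values for ax in axes.values())):
    concrete = state
    for path, v in zip(axes, values):
        concrete = set_path(concrete, path, v)
    combos.append(concrete)

# Slicing the pre-generated list yields a "subspace" made of combinations 1 and 2
subset = combos[1:3]
print(sorted(axes))                      # [('parameters', 'coupling', 'a'), ('parameters', 'model', 'J_N')]
print(subset[0]["parameters"]["model"])  # {'J_N': 0.2, 'tau': 100.0}
```

Fixed values such as tau pass through unchanged, and the original state still holds its Axis objects because set_path copies along the path instead of mutating.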
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| state | dict | State tree containing AbstractAxis instances and fixed values. | required |
| mode | str | Combination mode: ‘product’ for Cartesian product, ‘zip’ for parallel. Default is ‘zip’. | 'zip' |
| key | jax.random.PRNGKey | Random key for stochastic axes. If None, a new key is created. | None |
Raises
| Type | Description |
|---|---|
| ValueError | If mode is not 'product' or 'zip', or if no AbstractAxis instances are found. |
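The two documented error conditions can be expressed as a small check performed before any combinations are built. validate_space_args is a hypothetical helper written for illustration and is not part of the tvboptim API. The order in which Space itself performs these checks is not documented.

```python
VALID_MODES = ("product", "zip")


def validate_space_args(axes, mode):
    """Hypothetical check mirroring the two documented ValueError conditions."""
    if mode not in VALID_MODES:
        raise ValueError(f"mode must be 'product' or 'zip', got {mode!r}")
    if not axes:
        raise ValueError("no AbstractAxis instances found in state")


def raises_value_error(fn, *args):
    """Return True if calling fn(*args) raises ValueError."""
    try:
        fn(*args)
    except ValueError:
        return True
    return False


print(raises_value_error(validate_space_args, {"a": [1, 2]}, "grid"))  # True
print(raises_value_error(validate_space_args, {}, "zip"))              # True
print(raises_value_error(validate_space_args, {"a": [1, 2]}, "zip"))   # False
```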
Examples
>>> import jax.numpy as jnp
>>> import copy
>>> from tvboptim.types.spaces import Space, GridAxis, UniformAxis
>>>
>>> # Create exploration state from an existing state tree, base_state (assumed defined)
>>> exploration_state = copy.deepcopy(base_state)
>>> exploration_state.parameters.coupling.a = GridAxis(0.0, 1.0, 5)
>>> exploration_state.parameters.model.J_N = UniformAxis(0.0, 1.0, 3)
>>>
>>> # Product mode: 5 × 3 = 15 combinations
>>> space = Space(exploration_state, mode='product')
>>> print(f"Total combinations: {space.N}")
>>>
>>> # Access individual combinations
>>> first_state = space[0]
>>> print(first_state.parameters.coupling.a)
>>>
>>> # Create subspace
>>> subset = space[2:8]
>>> print(f"Subset size: {len(subset)}")
>>>
>>> # Parallel execution
>>> batched_states = space.collect(n_vmap=4, n_pmap=2)
Attributes
| Name | Description |
|---|---|
| N | Total number of parameter combinations. |
Methods
| Name | Description |
|---|---|
| collect | Generate batched states for parallel execution. |
collect
spaces.Space.collect(n_vmap=None, n_pmap=None, fill_value=jnp.nan, combine=True)
Generate batched states for parallel execution.
Creates parameter combinations laid out for JAX parallel execution: the leading axes are meant to be consumed by jax.pmap (across devices) and jax.vmap (vectorized within a device), so the whole space can be evaluated with nested pmap/vmap calls.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| n_vmap | int | Number of states to vectorize over using vmap. If None, defaults to 1. | None |
| n_pmap | int | Number of devices for parallel mapping with pmap. If None, defaults to 1. | None |
| fill_value | float | Value used for padding when reshaping to target dimensions. Default is jnp.nan. | jnp.nan |
| combine | bool | Whether to combine the static state with the parameter combinations into a complete state. If False, returns (axis_tree, static_state) tuple. Default is True. | True |
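With combine=False, collect returns the batched axis tree and the static state separately. The plain-Python sketch below illustrates that split and the merge that combine=True implies. split_tree and merge_trees are hypothetical helpers, and list leaves stand in for batched axis values. One plausible reason to keep the parts separate is that only the batched part then has to pass through pmap/vmap, while the static part can be closed over.

```python
def split_tree(tree, is_axis):
    """Split a nested dict into (axis_tree, static_tree) using a leaf predicate."""
    axis_part, static_part = {}, {}
    for key, value in tree.items():
        if isinstance(value, dict):
            a, s = split_tree(value, is_axis)
            if a:
                axis_part[key] = a
            if s:
                static_part[key] = s
        elif is_axis(value):
            axis_part[key] = value
        else:
            static_part[key] = value
    return axis_part, static_part


def merge_trees(a, b):
    """Recursively merge two nested dicts whose leaves do not overlap."""
    out = dict(b)
    for key, value in a.items():
        if isinstance(value, dict) and isinstance(out.get(key), dict):
            out[key] = merge_trees(value, out[key])
        else:
            out[key] = value
    return out


state = {"parameters": {"coupling": {"a": [0.0, 0.5]}, "model": {"tau": 100.0}}}
# In this illustration, list leaves play the role of batched axis values
axis_tree, static_state = split_tree(state, lambda v: isinstance(v, list))
print(axis_tree)     # {'parameters': {'coupling': {'a': [0.0, 0.5]}}}
print(static_state)  # {'parameters': {'model': {'tau': 100.0}}}
print(merge_trees(axis_tree, static_state) == state)  # True
```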
Returns
| Type | Description |
|---|---|
| dict or tuple | If combine=True: state tree whose batched leaves have leading dimensions (n_pmap, n_vmap, n_map), ready for parallel execution. If combine=False: tuple of (batched_axis_tree, static_state). |
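The Returns table gives the batched layout as (n_pmap, n_vmap, n_map) but does not state how n_map is chosen. The sketch below shows one plausible layout in plain Python, assuming n_map = ceil(N / (n_pmap * n_vmap)), that combinations fill the slots in row-major order, and that padding is appended after the real entries. collect_layout is a hypothetical helper, not part of tvboptim.

```python
import math


def collect_layout(values, n_vmap=None, n_pmap=None, fill_value=float("nan")):
    """Pad a flat list of combinations and nest it as [n_pmap][n_vmap][n_map].

    Assumed (not confirmed by the docs): n_map = ceil(N / (n_pmap * n_vmap)),
    row-major fill order, and padding appended after the N real entries.
    """
    n_vmap = n_vmap or 1  # documented default: None means 1
    n_pmap = n_pmap or 1
    n = len(values)
    n_map = math.ceil(n / (n_pmap * n_vmap))
    padded = list(values) + [fill_value] * (n_pmap * n_vmap * n_map - n)
    return [
        [padded[(p * n_vmap + v) * n_map:(p * n_vmap + v + 1) * n_map] for v in range(n_vmap)]
        for p in range(n_pmap)
    ]


# 15 combinations (as in the product-mode example), 2 devices x 4 vmap lanes
batched = collect_layout(list(range(15)), n_vmap=4, n_pmap=2)
print(len(batched), len(batched[0]), len(batched[0][0]))  # 2 4 2
print(batched[1][3])  # [14, nan]
```

Under these assumptions, N = 15 with n_pmap = 2 and n_vmap = 4 gives n_map = 2 and 16 slots, so exactly one slot holds fill_value.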
Warnings
If the total requested size (n_pmap * n_vmap * n_map) exceeds the number of combinations N, the extra slots are padded with fill_value.
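Padded slots hold fill_value rather than real parameters, so results computed from them are generally meaningless and can be dropped after the run. A minimal sketch, assuming the same row-major layout with padding at the end (unpad_results is a hypothetical helper): flatten the nested results in layout order and keep the first N entries.

```python
def unpad_results(batched, n_valid):
    """Flatten [n_pmap][n_vmap][n_map] results and drop trailing padded entries.

    Assumes padding was appended after the N real combinations in row-major order.
    """
    flat = [x for per_device in batched for per_lane in per_device for x in per_lane]
    return flat[:n_valid]


nan = float("nan")
# 2 devices x 2 lanes x 2 map steps = 8 slots for N = 7 combinations
results = [[[0, 1], [2, 3]], [[4, 5], [6, nan]]]
valid = unpad_results(results, n_valid=7)
print(valid)  # [0, 1, 2, 3, 4, 5, 6]
```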
Examples
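>>> import jax
>>> # Create a space (here assumed to have 100 combinations)
>>> space = Space(exploration_state, mode='product')
>>>
>>> # Batch for parallel execution: 2 devices, 8 vmapped states per device
>>> batched_states = space.collect(n_vmap=8, n_pmap=2)
>>> # collect returns a state tree, so inspect the shape of a batched leaf
>>> print(batched_states.parameters.coupling.a.shape)  # (2, 8, n_map)
>>>
>>> # Use with JAX transformations: simulation_fn (user-defined) takes one state,
>>> # and jax.lax.map loops over the remaining n_map axis on each device
>>> results = jax.pmap(jax.vmap(lambda s: jax.lax.map(simulation_fn, s)))(batched_states)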