spaces.GridSpace

spaces.GridSpace(state, n=1)

A Space for systematic grid sampling over parameter bounds.

GridSpace generates a regular grid of parameter values by linearly spacing points between the low and high bounds of each free parameter. The total number of parameter combinations is the product of grid points across all dimensions, enabling systematic exploration of the parameter space.
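The construction described above (linear spacing per dimension, then every combination) can be sketched in a few lines of NumPy. This is a minimal illustration, not the library's implementation: it assumes plain (low, high) tuples stand in for free Parameter objects, `grid_points` and `total_points` are hypothetical helpers, and the enumeration order used by GridSpace may differ.

```python
import itertools
import math

import numpy as np


def grid_points(bounds, n):
    """Enumerate a regular grid over per-parameter (low, high) bounds.

    bounds: dict name -> (low, high), a hypothetical stand-in for free Parameters.
    n: an int (same resolution for every dimension) or a dict name -> int.
    """
    # Broadcast a scalar n to one entry per dimension, like GridSpace's n PyTree.
    if isinstance(n, int):
        n = {name: n for name in bounds}
    names = list(bounds)
    # Linearly spaced values from low to high (both endpoints included).
    axes = [np.linspace(bounds[k][0], bounds[k][1], n[k]) for k in names]
    # Cartesian product: the first dimension varies slowest.
    return [dict(zip(names, map(float, combo))) for combo in itertools.product(*axes)]


def total_points(n):
    """N is the product of the per-dimension point counts."""
    return math.prod(n.values())


bounds = {"param1": (0.0, 1.0), "param2": (-2.0, 2.0)}
n = {"param1": 10, "param2": 5}
pts = grid_points(bounds, n)
print(len(pts), total_points(n))  # 50 50
print(pts[0])   # {'param1': 0.0, 'param2': -2.0}
print(pts[-1])  # {'param1': 1.0, 'param2': 2.0}
```

Because the number of combinations is a product, N grows exponentially with the number of free parameters: 10 points across 6 dimensions already gives 10**6 grid points.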

Parameters

Name Type Description Default
state dict or PyTree The parameter state template containing Parameter objects with defined bounds. All free parameters must have both low and high bounds specified. required
n int Number of grid points along each free-parameter dimension. After instantiation, n is converted into a PyTree whose leaves give the number of points along each dimension. To use a different number of points per dimension, modify this PyTree after instantiation. 1

Attributes

Name Type Description
state PyTree The original parameter state template.
n PyTree Number of grid points for each free parameter. All leaves initially equal the n passed to the constructor; modify them after instantiation to use a different resolution per dimension.
N int (property) Total number of grid points, i.e. the product of the leaves of n.

Raises

Name Type Description
ValueError If any free parameter lacks a low or a high bound.
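The bounds requirement can be illustrated with a small validation sketch. `Param` and `check_bounds` are hypothetical stand-ins, not the library's Parameter class or its internal check; the sketch assumes a missing bound is stored as None and walks only a flat dict rather than an arbitrary PyTree.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Param:
    """Hypothetical stand-in for a Parameter with optional bounds."""
    name: str
    value: float
    low: Optional[float] = None
    high: Optional[float] = None
    free: bool = False


def check_bounds(state):
    """Raise ValueError if any free parameter is missing a low or high bound."""
    for key, leaf in state.items():
        # Plain values (e.g. 5.0) and non-free parameters are not gridded, so skip them.
        if isinstance(leaf, Param) and leaf.free:
            if leaf.low is None or leaf.high is None:
                raise ValueError(f"free parameter {key!r} needs both low and high bounds")


check_bounds({"a": Param("a", 0.0, low=0.0, high=1.0, free=True), "b": 5.0})  # passes
try:
    check_bounds({"a": Param("a", 0.0, low=0.0, free=True)})
except ValueError as err:
    print(err)  # free parameter 'a' needs both low and high bounds
```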

Examples

>>> # Define parameter space with bounds
>>> state = {
...     'param1': Parameter("param1", 0.0, low=0.0, high=1.0, free=True),
...     'param2': Parameter("param2", 0.0, low=-2.0, high=2.0, free=True),
...     'static_param': 5.0
... }
>>> 
>>> # Create a grid space with 10 points per dimension (10 * 10 = 100 combinations)
>>> grid = GridSpace(state, n=10)
>>> grid.N
100
>>> # Use fewer points along a single dimension
>>> grid.n['param2'] = 5
>>> grid.N
50
>>>
>>> # Iterate over all grid points (simulate is a user-supplied model function)
>>> for sample_state in grid:
...     result = simulate(sample_state)
>>> 
>>> # Get specific grid point
>>> first_sample = grid[0]
>>> corner_sample = grid[-1]  # Last grid point
>>> 
>>> # Slicing returns the selected grid points as a DataSpace
>>> subset = grid[0:50]  # First 50 grid points
>>> 
>>> # Collect for parallel execution
>>> batched_state = grid.collect(n_vmap=10, n_pmap=5)
>>> 
>>> # Convert to list for external processing
>>> all_states = grid.as_list()

Methods

Name Description
collect Generate and reshape grid points for efficient parallel execution.

collect

spaces.GridSpace.collect(n_vmap=None, n_pmap=None, fill_value=jnp.nan)

Generate and reshape grid points for efficient parallel execution.

Builds the full parameter grid and reshapes it into a layout that can be consumed directly by JAX's vmap and pmap transformations.

Parameters

Name Type Description Default
n_vmap int Number of states to vectorize over using vmap. If None, defaults to 1. None
n_pmap int Number of devices for parallel mapping with pmap. If None, defaults to 1. None
fill_value float Value used to pad the trailing entries when the reshaped size, n_pmap * n_vmap * n_map, exceeds N. jnp.nan

Returns

Name Type Description
PyTree Reshaped state in which the grid points are laid out with shape (n_pmap, n_vmap, n_map, …), where n_map is computed so that all N grid points fit within the n_pmap * n_vmap parallel slots.
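The padding and reshaping step can be sketched for a single leaf with NumPy. This assumes n_map is the smallest integer with n_pmap * n_vmap * n_map >= N (i.e. a ceiling division) and that grid points fill the slots in C order; `pad_and_reshape` is a hypothetical helper, and the library's actual assignment of points to slots may differ.

```python
import math

import numpy as np


def pad_and_reshape(values, n_vmap=None, n_pmap=None, fill_value=np.nan):
    """Pad a flat array of N grid values and reshape it to (n_pmap, n_vmap, n_map)."""
    # None falls back to 1, matching the documented defaults.
    n_vmap = 1 if n_vmap is None else n_vmap
    n_pmap = 1 if n_pmap is None else n_pmap
    N = values.shape[0]
    # Smallest n_map such that n_pmap * n_vmap * n_map >= N.
    n_map = math.ceil(N / (n_pmap * n_vmap))
    total = n_pmap * n_vmap * n_map
    # Append fill_value entries so the array divides evenly into the slots.
    padded = np.concatenate([values, np.full(total - N, fill_value, dtype=float)])
    return padded.reshape(n_pmap, n_vmap, n_map)


leaf = np.arange(7, dtype=float)  # 7 grid values of one parameter
out = pad_and_reshape(leaf, n_vmap=2, n_pmap=2)
print(out.shape)                  # (2, 2, 2)
print(int(np.isnan(out).sum()))   # 1
```

For a whole state, the same function would be applied to every leaf, for example with `jax.tree_util.tree_map`. Padded entries hold fill_value, so results computed from them should be discarded after execution.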