spaces.UniformSpace

spaces.UniformSpace(state, N=1, key=jax.random.key(0))

A Space for uniform random sampling over parameter bounds.

UniformSpace draws parameter samples uniformly at random within the bounds of each free parameter, so every free parameter must define both low and high bounds. The space generates N samples and supports iteration and indexing for accessing individual parameter configurations.
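To show what uniform sampling within per-parameter bounds involves, here is a minimal JAX sketch. It is not the library's implementation. The helper sample_uniform and its plain {name: (low, high)} bounds dict are hypothetical, and giving each parameter its own subkey from jax.random.split is an assumed design choice.

```python
import jax

def sample_uniform(bounds, N, key):
    """Draw N uniform samples for each parameter in a hypothetical bounds dict.

    bounds: {name: (low, high)}. Returns {name: array of shape (N,)}.
    """
    # Fixed (sorted) order so the same key always maps to the same parameter.
    names = sorted(bounds)
    subkeys = jax.random.split(key, len(names))
    samples = {}
    for i, name in enumerate(names):
        low, high = bounds[name]
        # jax.random.uniform samples from the half-open interval [minval, maxval).
        samples[name] = jax.random.uniform(subkeys[i], (N,), minval=low, maxval=high)
    return samples

samples = sample_uniform(
    {"param1": (0.0, 1.0), "param2": (-2.0, 2.0)}, N=100, key=jax.random.key(42)
)
```

Reusing the same key reproduces the same samples, which is the role of the key argument.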

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| state | dict or PyTree | The parameter state template, containing Value objects with defined bounds. Every free parameter must specify both low and high bounds. | required |
| N | int | Number of random samples to generate. | 1 |
| key | jax.random.PRNGKey | Random key for reproducible sampling. | jax.random.key(0) |

Attributes

| Name | Type | Description |
|------|------|-------------|
| state | dict or PyTree | The original parameter state template. |
| N | int | Number of samples to generate. |
| key | jax.random.PRNGKey | Random key used for sampling. |

Raises

| Name | Type | Description |
|------|------|-------------|
| | ValueError | If any free parameter in state lacks a low or high bound. |
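The bounds check behind this ValueError can be sketched as a simple scan over the state. This is a hedged illustration only: BoundedParam is a hypothetical stand-in for the library's parameter objects, and the sketch handles a flat dict rather than an arbitrary PyTree.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class BoundedParam:
    # Hypothetical stand-in for the library's parameter objects.
    value: float
    low: Optional[float] = None
    high: Optional[float] = None
    free: bool = True

def check_bounds(state):
    """Raise ValueError if any free parameter lacks a low or high bound."""
    for name, p in state.items():
        # Plain values (e.g. a static float) and non-free parameters are skipped.
        if isinstance(p, BoundedParam) and p.free and (p.low is None or p.high is None):
            raise ValueError(f"free parameter {name!r} needs both low and high bounds")
```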

Examples

>>> import jax
>>> # Define a parameter space with bounds
>>> state = {
...     'param1': Parameter("param1", 0.0, low=0.0, high=1.0, free=True),
...     'param2': Parameter("param2", 0.0, low=-2.0, high=2.0, free=True),
...     'static_param': 5.0
... }
>>> 
>>> # Create uniform space with 100 samples
>>> key = jax.random.key(42)
>>> space = UniformSpace(state, N=100, key=key)
>>> 
>>> # Iterate over samples
>>> for sample_state in space:
...     result = simulate(sample_state)
>>> 
>>> # Get specific sample
>>> first_sample = space[0]
>>> 
>>> # Collect for parallel execution
>>> batched_state = space.collect(n_vmap=10, n_pmap=2)
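One plausible way to consume a collected batch, assuming axis 0 is meant for jax.pmap, axis 1 for jax.vmap, and axis 2 for a sequential jax.lax.map, is sketched below. The function toy_simulate is a hypothetical stand-in for simulate, and the demo uses n_pmap=1 so it runs on a single device.

```python
import jax
import jax.numpy as jnp

def toy_simulate(x):
    # Hypothetical stand-in for `simulate`; any per-sample function works.
    return x ** 2

def run_batched(batched, f):
    """Evaluate f on an array laid out as (n_pmap, n_vmap, n_map)."""
    # Innermost: loop sequentially over the n_map axis.
    per_lane = lambda xs: jax.lax.map(f, xs)
    # Middle: vectorize over the n_vmap axis; outer: one slice per device.
    return jax.pmap(jax.vmap(per_lane))(batched)

# Single-device demo: n_pmap=1, n_vmap=10, n_map=5 (50 samples).
batched = jnp.arange(50.0).reshape(1, 10, 5)
results = run_batched(batched, toy_simulate)
```

Padded entries (filled with fill_value, NaN by default) produce NaN results as well, so they can be masked out afterwards with jnp.isnan on the inputs.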

Methods

| Name | Description |
|------|-------------|
| collect | Generate the samples and reshape them for efficient parallel execution. |

collect

spaces.UniformSpace.collect(n_vmap=None, n_pmap=None, fill_value=jnp.nan)

Generate the samples and reshape them for efficient parallel execution.

Draws all N samples and organizes them into a layout suited to JAX's vectorization (vmap) and parallelization (pmap) primitives, padding with fill_value when the layout holds more slots than N.
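The padding and reshaping step can be sketched as follows. This is a minimal illustration under stated assumptions, not the library's code: the helpers pad_and_reshape and collect_sketch are hypothetical, and n_map is assumed to be the ceiling of N / (n_pmap × n_vmap), with filler appended at the end before reshaping to (n_pmap, n_vmap, n_map, ...).

```python
import jax
import jax.numpy as jnp

def pad_and_reshape(x, n_vmap=None, n_pmap=None, fill_value=jnp.nan):
    """Pad a leaf of shape (N, ...) and reshape it to (n_pmap, n_vmap, n_map, ...)."""
    n_vmap = 1 if n_vmap is None else n_vmap
    n_pmap = 1 if n_pmap is None else n_pmap
    N = x.shape[0]
    per_step = n_pmap * n_vmap            # samples handled per n_map step
    n_map = -(-N // per_step)             # ceiling division: enough steps for all N
    pad = n_map * per_step - N            # filler entries appended at the end
    filler = jnp.full((pad,) + x.shape[1:], fill_value, dtype=x.dtype)
    padded = jnp.concatenate([x, filler], axis=0)
    return padded.reshape((n_pmap, n_vmap, n_map) + x.shape[1:])

def collect_sketch(samples, n_vmap=None, n_pmap=None, fill_value=jnp.nan):
    # Apply the same padding and reshape to every leaf of a PyTree of samples.
    return jax.tree_util.tree_map(
        lambda leaf: pad_and_reshape(leaf, n_vmap, n_pmap, fill_value), samples
    )
```

Padding with NaN by default keeps the filler easy to detect, at the cost of requiring floating-point sample arrays.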

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| n_vmap | int | Number of states to vectorize over using vmap. If None, defaults to 1. | None |
| n_pmap | int | Number of devices for parallel mapping with pmap. If None, defaults to 1. | None |
| fill_value | float | Value used to pad the arrays when the layout n_pmap × n_vmap × n_map holds more slots than N. | jnp.nan |

Returns

| Name | Type | Description |
|------|------|-------------|
| | PyTree | Reshaped state with samples organized along the leading axes (n_pmap, n_vmap, n_map, …), where n_map is chosen so that all N samples fit within the parallel execution layout. |
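Assuming n_map is the smallest step count that fits all N samples (a ceiling division, which the description above implies but does not state), the layout arithmetic is shown below. The padding count n_pad is a symbol introduced here for illustration. For the example call collect(n_vmap=10, n_pmap=2):

$$
n_{\text{map}} = \left\lceil \frac{N}{n_{\text{pmap}}\, n_{\text{vmap}}} \right\rceil,
\qquad
n_{\text{pad}} = n_{\text{pmap}}\, n_{\text{vmap}}\, n_{\text{map}} - N
$$

$$
N = 100:\quad n_{\text{map}} = \left\lceil \tfrac{100}{20} \right\rceil = 5,\quad n_{\text{pad}} = 100 - 100 = 0
$$

$$
N = 101:\quad n_{\text{map}} = \left\lceil \tfrac{101}{20} \right\rceil = 6,\quad n_{\text{pad}} = 120 - 101 = 19
$$

In the second case, 19 slots would hold fill_value.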