spaces.UniformSpace
spaces.UniformSpace(state, N=1, key=jax.random.key(0))
A Space for uniform random sampling over parameter bounds.
UniformSpace draws random parameter samples uniformly between the low and high bounds of each free parameter, so every free parameter must define both bounds. The space generates N samples and supports iteration and indexing for accessing individual parameter configurations.
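The sampling idea can be sketched in plain JAX. This is a minimal illustration, not the library's implementation: it assumes each free parameter is reduced to a `(low, high)` tuple in a hypothetical `bounds` dict (static entries such as `static_param` are simply not included), and it draws one independent uniform vector of length `N` per parameter.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the free parameters of `state`:
# each one is described only by its (low, high) bounds.
bounds = {"param1": (0.0, 1.0), "param2": (-2.0, 2.0)}


def sample_uniform(bounds, N, key):
    """Draw N samples per parameter, uniformly between low and high."""
    # Split the key so each parameter gets its own independent random stream.
    keys = jax.random.split(key, len(bounds))
    samples = {}
    for subkey, (name, (low, high)) in zip(keys, sorted(bounds.items())):
        samples[name] = jax.random.uniform(subkey, (N,), minval=low, maxval=high)
    return samples


samples = sample_uniform(bounds, N=100, key=jax.random.key(0))
```

Because the samples depend only on the key, calling `sample_uniform` again with the same key reproduces them exactly, which is the role the `key` parameter plays below.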
Parameters
Name | Type | Description | Default |
---|---|---|---|
state | dict or PyTree | The parameter state template containing Value objects with defined bounds. All free parameters must have both low and high bounds specified. | required |
N | int | Number of random samples to generate. Default is 1. | 1 |
key | jax.random.PRNGKey | Random key for reproducible sampling. Default is jax.random.key(0). | jax.random.key(0) |
Attributes
Name | Type | Description |
---|---|---|
state | dict or PyTree | The original parameter state template. |
N | int | Number of samples to generate. |
key | jax.random.PRNGKey | Random key used for sampling. |
Raises
Type | Description |
---|---|
ValueError | If any free parameter in state lacks a defined low or high bound. |
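The bounds check behind this error can be sketched as follows. It is a hypothetical helper, not the library's code: it assumes free parameters are given as a mapping from name to `(low, high)`, with `None` marking a missing bound, and it reports every offending parameter at once.

```python
def check_bounds(free_params):
    """Raise ValueError if any free parameter lacks a low or high bound.

    free_params: hypothetical mapping name -> (low, high), where None
    marks a bound that was never defined.
    """
    missing = [
        name
        for name, (low, high) in free_params.items()
        if low is None or high is None
    ]
    if missing:
        raise ValueError(f"Free parameters without low/high bounds: {missing}")
```

Collecting all missing names before raising lets a user fix every under-specified parameter in one pass instead of one error at a time.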
Examples
```python
>>> import jax
>>> # Define parameter space with bounds
>>> state = {
...     'param1': Parameter("param1", 0.0, low=0.0, high=1.0, free=True),
...     'param2': Parameter("param2", 0.0, low=-2.0, high=2.0, free=True),
...     'static_param': 5.0
... }
>>>
>>> # Create uniform space with 100 samples
>>> key = jax.random.key(42)
>>> space = UniformSpace(state, N=100, key=key)
>>>
>>> # Iterate over samples (simulate stands in for your own model function)
>>> for sample_state in space:
...     result = simulate(sample_state)
>>>
>>> # Get specific sample
>>> first_sample = space[0]
>>>
>>> # Collect for parallel execution
>>> batched_state = space.collect(n_vmap=10, n_pmap=2)
```
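To make the iteration and indexing behaviour concrete, here is a self-contained toy class with the same shape of interface. It is a sketch under stated assumptions, not UniformSpace itself: `ToyUniformSpace` is hypothetical, takes plain `(low, high)` tuples instead of `Parameter` objects, and merges a separate `static` dict (playing the role of `static_param`) into every sample.

```python
import jax


class ToyUniformSpace:
    """Hypothetical, simplified analogue of UniformSpace."""

    def __init__(self, free_bounds, static=None, N=1, key=jax.random.key(0)):
        self.N = N
        self.static = dict(static or {})
        keys = jax.random.split(key, len(free_bounds))
        # Pre-draw all samples: one array of shape (N,) per free parameter.
        self._samples = {
            name: jax.random.uniform(k, (N,), minval=lo, maxval=hi)
            for k, (name, (lo, hi)) in zip(keys, sorted(free_bounds.items()))
        }

    def __len__(self):
        return self.N

    def __getitem__(self, i):
        # The i-th configuration: static values plus one sampled value per
        # free parameter.
        if not -self.N <= i < self.N:
            raise IndexError(i)
        sampled = {name: s[i] for name, s in self._samples.items()}
        return {**self.static, **sampled}

    def __iter__(self):
        for i in range(self.N):
            yield self[i]


space = ToyUniformSpace(
    {"param1": (0.0, 1.0), "param2": (-2.0, 2.0)},
    static={"static_param": 5.0},
    N=100,
    key=jax.random.key(42),
)
first = space[0]
```

Drawing every sample up front keeps indexing cheap and makes `space[i]` and the i-th item of iteration identical for a given key.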
Methods
Name | Description |
---|---|
collect | Generate the samples and reshape them for efficient parallel execution. |
collect
spaces.UniformSpace.collect(n_vmap=None, n_pmap=None, fill_value=jnp.nan)
Generate the samples and reshape them for efficient parallel execution.
Draws all N samples and organizes them into a layout suited to JAX's vectorization (vmap) and parallelization (pmap) primitives.
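The padding and reshaping step can be sketched for a single array and then mapped over a PyTree. This is an assumption-laden illustration, not the library's code: `pad_and_reshape` and `collect_tree` are hypothetical names, samples are assumed to lie along the leading axis, padding is appended at the end, and the result is filled in row-major order.

```python
import jax
import jax.numpy as jnp


def pad_and_reshape(x, n_vmap=None, n_pmap=None, fill_value=jnp.nan):
    """Lay out the N samples on x's leading axis as (n_pmap, n_vmap, n_map, ...)."""
    n_vmap = 1 if n_vmap is None else n_vmap
    n_pmap = 1 if n_pmap is None else n_pmap
    n = x.shape[0]
    # Smallest n_map with n_pmap * n_vmap * n_map >= n (integer ceiling).
    n_map = -(-n // (n_pmap * n_vmap))
    total = n_pmap * n_vmap * n_map
    # Pad the leading axis with fill_value up to the requested total size.
    pad = jnp.full((total - n,) + x.shape[1:], fill_value, dtype=x.dtype)
    padded = jnp.concatenate([x, pad], axis=0)
    return padded.reshape((n_pmap, n_vmap, n_map) + x.shape[1:])


def collect_tree(tree, **kwargs):
    # Apply the same layout to every leaf of a PyTree of sample arrays.
    return jax.tree_util.tree_map(lambda leaf: pad_and_reshape(leaf, **kwargs), tree)
```

With N = 100, n_vmap = 10 and n_pmap = 2, n_map is 5 and no padding is needed; with 7 samples on a 2 × 2 layout, n_map is 2 and one slot holds fill_value. The NaN default assumes floating-point leaves.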
Parameters
Name | Type | Description | Default |
---|---|---|---|
n_vmap | int | Number of states to vectorize over using vmap. If None, defaults to 1. | None |
n_pmap | int | Number of devices for parallel mapping with pmap. If None, defaults to 1. | None |
fill_value | float | Value used to pad the sample arrays when n_pmap × n_vmap × n_map exceeds N. Default is jnp.nan. | jnp.nan |
Returns
Type | Description |
---|---|
PyTree | Reshaped state with samples organized as (n_pmap, n_vmap, n_map, …), where n_map is computed so that all N samples fit across the parallel execution layout; any extra slots hold fill_value. |
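One way to consume this layout is to peel off one axis per JAX primitive: `jax.pmap` over axis 0, `jax.vmap` over the next, and `jax.lax.map` sequentially over n_map. The sketch below assumes that reading of the axis order; `simulate` and `run_batched` are hypothetical, and n_pmap is set to the local device count because pmap's leading axis cannot exceed it (on a CPU-only machine this is 1).

```python
import jax
import jax.numpy as jnp


def simulate(theta):
    # Hypothetical per-sample model: a scalar function of one parameter value.
    return jnp.sin(theta) ** 2


def run_batched(batched):
    """Evaluate simulate over an array laid out as (n_pmap, n_vmap, n_map)."""
    # pmap strips axis 0 (one slice per device), vmap vectorizes over the
    # next axis, and lax.map loops sequentially over the n_map axis.
    per_device = jax.vmap(lambda chunk: jax.lax.map(simulate, chunk))
    return jax.pmap(per_device)(batched)


n_pmap = jax.local_device_count()
batched = jnp.linspace(0.0, 1.0, n_pmap * 4 * 3).reshape(n_pmap, 4, 3)
out = run_batched(batched)
```

If the input was padded with NaN, the padded slots simply produce NaN outputs; assuming padding sits at the end in row-major order, `out.reshape(-1)[:N]` recovers the N real results.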