spaces.DataSpace

spaces.DataSpace(state)

A Space of data for parallel execution over parameter sets.

DataSpace manages a collection of parameter states in which the first dimension of every free parameter is the data dimension. The Executor performs operations along this dimension, which allows parallel computation across multiple parameter configurations.
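To make the "first dimension is the data dimension" convention concrete, here is a minimal sketch in plain JAX. It does not use DataSpace or the Executor; the helper `take` and the toy `model` are hypothetical names introduced only for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical state: every leaf has 3 entries along the first (data) dimension.
state = {
    "param1": jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
    "param2": jnp.array([0.1, 0.2, 0.3]),
}


def take(tree, i):
    """Select the i-th entry along the leading axis of every leaf."""
    return jax.tree_util.tree_map(lambda leaf: leaf[i], tree)


def model(s):
    """Toy model (an assumption): consumes a single parameter set."""
    return jnp.sum(s["param1"]) * s["param2"]


# One parameter set: the data dimension is removed from each leaf.
first = take(state, 0)

# All parameter sets at once: vmap maps over axis 0 of every leaf.
batched = jax.vmap(model)(state)
```

Here `first["param1"]` has shape `(2,)` and `batched` has shape `(3,)`, one result per parameter set.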

Parameters

Name Type Description Default
state PyTree The initial state containing parameter data. All free parameters must have the same size in the first dimension (data dimension). required

Attributes

Name Type Description
state PyTree The original input state.
N int Number of data points (size of first dimension).

Raises

Type Description
AssertionError If free parameters don’t have the same number of data points in the first dimension.
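The consistency requirement can be sketched as a check over the leaves of the PyTree. This is an illustration of the rule, not the library's own validation: the function name `leading_size` is hypothetical, and the sketch assumes every leaf is a free parameter with at least one dimension.

```python
import jax
import jax.numpy as jnp


def leading_size(state):
    """Return the first-dimension size shared by all leaves, or raise.

    Assumption: every leaf is a free parameter with ndim >= 1.
    """
    sizes = {jnp.shape(leaf)[0] for leaf in jax.tree_util.tree_leaves(state)}
    assert len(sizes) == 1, f"inconsistent data dimension sizes: {sorted(sizes)}"
    return sizes.pop()


good = {"a": jnp.zeros((3, 2)), "b": jnp.zeros(3)}
bad = {"a": jnp.zeros((3, 2)), "b": jnp.zeros(4)}

n_good = leading_size(good)  # all leaves agree on the leading size

try:
    leading_size(bad)
    raised = False
except AssertionError:
    raised = True  # leaves disagree: 3 vs 4
```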

Examples

>>> import jax.numpy as jnp
>>> state = {'param1': jnp.array([[1, 2], [3, 4], [5, 6]]),
...          'param2': jnp.array([0.1, 0.2, 0.3])}
>>> ds = DataSpace(state)
>>> ds.N
3
>>> single_state = ds[0]  # Get the first parameter set
>>> for s in ds:          # Iterate over all parameter sets
...     model(s)          # `model` stands in for any user-defined callable

Methods

Name Description
collect Generate and reshape grid points for efficient parallel execution.

collect

spaces.DataSpace.collect(n_vmap=None, n_pmap=None, fill_value=jnp.nan)

Generate and reshape grid points for efficient parallel execution.

Creates the full parameter grid and organizes it into a structure optimized for JAX’s vectorization and parallelization primitives.

Parameters

Name Type Description Default
n_vmap int Number of states to vectorize over using vmap. If None, defaults to 1. None
n_pmap int Number of devices for parallel mapping with pmap. If None, defaults to 1. None
fill_value float Value used to pad arrays when the requested layout (n_pmap * n_vmap * n_map entries) is larger than N. jnp.nan

Returns

Type Description
PyTree Reshaped state whose leaves have shape (n_pmap, n_vmap, n_map, …), where n_map is computed so that all N grid points fit across the parallel execution strategy.
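The padding and reshaping described above can be approximated in plain JAX as follows. This is a sketch under stated assumptions, not the implementation of collect: the function name `pad_and_reshape` is hypothetical, n_map is assumed to be ceil(N / (n_pmap * n_vmap)), padding is assumed to be appended at the end, and the reshape is assumed to be row-major. The actual ordering used by the library may differ.

```python
import math

import jax
import jax.numpy as jnp


def pad_and_reshape(state, n_vmap=1, n_pmap=1, fill_value=jnp.nan):
    """Pad every leaf along axis 0 and reshape to (n_pmap, n_vmap, n_map, ...).

    Assumptions: leaves are floating point (NaN cannot be stored in integer
    arrays) and all leaves share the same leading size.
    """
    n = jax.tree_util.tree_leaves(state)[0].shape[0]
    n_map = math.ceil(n / (n_pmap * n_vmap))  # assumed definition of n_map
    total = n_pmap * n_vmap * n_map

    def reshape_leaf(leaf):
        # Pad only the data dimension; trailing dimensions are untouched.
        pad_width = [(0, total - n)] + [(0, 0)] * (leaf.ndim - 1)
        padded = jnp.pad(leaf, pad_width, constant_values=fill_value)
        return padded.reshape((n_pmap, n_vmap, n_map) + leaf.shape[1:])

    return jax.tree_util.tree_map(reshape_leaf, state)


# N = 5 with n_vmap = 2 needs n_map = 3, so one padded slot is added.
state = {"a": jnp.ones((5, 2)), "b": jnp.arange(5.0)}
out = pad_and_reshape(state, n_vmap=2, n_pmap=1)
```

With these inputs, `out["a"]` has shape `(1, 2, 3, 2)` and `out["b"]` has shape `(1, 2, 3)`; the final slot of `out["b"]` holds the fill value, so downstream code may need to mask NaN results.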