# spaces.DataSpace { #tvboptim.spaces.DataSpace }

```python
spaces.DataSpace(state)
```

A Space of data for parallel execution over parameter sets.

DataSpace manages collections of parameter states where the first dimension
represents the data dimension. The Executor performs operations along this
dimension, enabling efficient parallel computation across multiple parameter
configurations.

## Parameters {.doc-section .doc-section-parameters}

| Name   | Type               | Description                                                                                                                        | Default    |
|--------|--------------------|------------------------------------------------------------------------------------------------------------------------------------|------------|
| state  | [PyTree](`PyTree`) | The initial state containing parameter data. All free parameters must have the same size in the first dimension (data dimension). | _required_ |

## Attributes {.doc-section .doc-section-attributes}

| Name   | Type               | Description                                      |
|--------|--------------------|--------------------------------------------------|
| state  | [PyTree](`PyTree`) | The original input state.                        |
| N      | [int](`int`)       | Number of data points (size of first dimension). |

## Raises {.doc-section .doc-section-raises}

| Name   | Type                               | Description                                                                          |
|--------|------------------------------------|--------------------------------------------------------------------------------------|
|        | [AssertionError](`AssertionError`) | If free parameters don't have the same number of data points in the first dimension. |

## Examples {.doc-section .doc-section-examples}

```python
>>> import jax.numpy as jnp
>>> state = {'param1': jnp.array([[1, 2], [3, 4], [5, 6]]),
...          'param2': jnp.array([0.1, 0.2, 0.3])}
>>> ds = DataSpace(state)
>>> ds.N
3
>>> single_state = ds[0]  # Get first parameter set
>>> for params in ds:  # Iterate over all parameter sets
...     model(params)  # `model` stands for any user-defined function
```

## Methods

| Name | Description |
| --- | --- |
| [collect](#tvboptim.spaces.DataSpace.collect) | Generate and reshape grid points for efficient parallel execution. |
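The data-dimension contract described above (every free parameter shares the same leading size `N`, indexing picks one data point, and computation runs along that axis) can be sketched with plain JAX pytree utilities. This is a minimal illustration, not the tvboptim implementation. `data_size`, `select`, and `model` are hypothetical names. The sketch also assumes that every leaf of the state is a free parameter stored as an array with at least one dimension.

```python
import jax
import jax.numpy as jnp


def data_size(state):
    """Hypothetical helper: return N, the shared first-dimension size of all leaves.

    Assumes every leaf is a free parameter array with at least one dimension.
    """
    sizes = {leaf.shape[0] for leaf in jax.tree_util.tree_leaves(state)}
    # Mirrors the documented AssertionError on inconsistent data dimensions.
    assert len(sizes) == 1, f"Inconsistent data dimension sizes: {sorted(sizes)}"
    return sizes.pop()


def select(state, i):
    """Hypothetical helper: take data point i from every leaf of the pytree."""
    return jax.tree_util.tree_map(lambda leaf: leaf[i], state)


def model(params):
    """Hypothetical per-point model operating on a single parameter set."""
    return params["param1"].sum() * params["param2"]


state = {
    "param1": jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
    "param2": jnp.array([0.1, 0.2, 0.3]),
}

N = data_size(state)                            # 3
first = select(state, 0)                        # param1 -> [1., 2.], param2 -> 0.1
points = [select(state, i) for i in range(N)]   # one parameter set per data point

# Vectorizing over the leading axis of every leaf evaluates all points at once.
results = jax.vmap(model)(state)                # approximately [0.3, 1.4, 3.3]
```

`jax.vmap` with its default `in_axes=0` maps over the first axis of every leaf in the pytree. This is why a shared data dimension is required: a mismatch in leading sizes would make the batch ill-defined.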
### collect { #tvboptim.spaces.DataSpace.collect }

```python
spaces.DataSpace.collect(n_vmap=None, n_pmap=None, fill_value=jnp.nan)
```

Generate and reshape grid points for efficient parallel execution.

Creates the full parameter grid and organizes it into a structure
optimized for JAX's vectorization and parallelization primitives.

#### Parameters {.doc-section .doc-section-parameters}

| Name       | Type             | Description                                                                        | Default   |
|------------|------------------|------------------------------------------------------------------------------------|-----------|
| n_vmap     | [int](`int`)     | Number of states to vectorize over using vmap. If None, defaults to 1.             | `None`    |
| n_pmap     | [int](`int`)     | Number of devices for parallel mapping with pmap. If None, defaults to 1.          | `None`    |
| fill_value | [float](`float`) | Value used to pad arrays when total requested size exceeds N. Default is `jnp.nan`. | `jnp.nan` |

#### Returns {.doc-section .doc-section-returns}

| Name   | Type               | Description                                                                                                                                                                   |
|--------|--------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|        | [PyTree](`PyTree`) | Reshaped state with grid points organized as `(n_pmap, n_vmap, n_map, ...)`, where `n_map` is computed to accommodate all N grid points across the parallel execution strategy. |
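The padding and reshaping that `collect` describes can be sketched as follows. This is a hedged illustration, not the library's code. `collect_sketch` and `model` are hypothetical names, and the sketch makes three assumptions:

- `n_map` is the smallest integer with `n_pmap * n_vmap * n_map >= N`.
- Data points fill the `(n_pmap, n_vmap, n_map)` block in row-major order.
- All leaves are floating-point arrays, so a `NaN` fill value is representable.

```python
import jax
import jax.numpy as jnp


def collect_sketch(state, n_vmap=None, n_pmap=None, fill_value=jnp.nan):
    """Hypothetical re-implementation: pad the data dimension, then reshape
    every leaf from (N, ...) to (n_pmap, n_vmap, n_map, ...).

    Assumes all leaves share the same first-dimension size N and are floats.
    """
    n_vmap = 1 if n_vmap is None else n_vmap
    n_pmap = 1 if n_pmap is None else n_pmap
    N = jax.tree_util.tree_leaves(state)[0].shape[0]

    per_map = n_pmap * n_vmap
    n_map = (N + per_map - 1) // per_map  # ceiling division: enough slots for N points
    total = per_map * n_map               # total slots, >= N

    def pad_and_reshape(leaf):
        # Append (total - N) rows of fill_value along the data dimension.
        pad = jnp.full((total - N,) + leaf.shape[1:], fill_value, dtype=leaf.dtype)
        padded = jnp.concatenate([leaf, pad], axis=0)
        # Row-major reshape: consecutive data points land along the n_map axis.
        return padded.reshape((n_pmap, n_vmap, n_map) + leaf.shape[1:])

    return jax.tree_util.tree_map(pad_and_reshape, state)


def model(params):
    """Hypothetical per-point model."""
    return params["x"] ** 2


state = {"x": jnp.arange(5.0)}
grid = collect_sketch(state, n_vmap=2)  # "x" has shape (1, 2, 3); the last slot is NaN padding

# Outer axis: jax.pmap on multiple devices; jax.vmap here so it runs on one device.
# Middle axis: jax.vmap. Inner axis: sequential jax.lax.map over the n_map chunk.
run = jax.vmap(jax.vmap(lambda chunk: jax.lax.map(model, chunk)))
out = run(grid)                         # shape (1, 2, 3)
flat = out.reshape(-1)[:5]              # drop the padded slot: [0., 1., 4., 9., 16.]
```

Padding to a full `(n_pmap, n_vmap, n_map)` block keeps every device and vector lane working on equally shaped inputs. The price is a few wasted evaluations on fill values, which must be discarded (here by truncating to the first `N` results) before the outputs are used.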