execution.ParallelExecution
execution.ParallelExecution(
    model,
    statespace,
    *args,
    n_vmap=1,
    n_pmap=1,
    **kwargs,
)
Efficient parallel execution of models across parameter spaces using JAX.
ParallelExecution orchestrates the parallel computation of a model function across all parameter combinations in a given state space. It leverages JAX’s pmap (for multi-device parallelism) and vmap (for vectorization).
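The pmap-over-vmap pattern can be sketched in plain JAX. This is a minimal illustration under stated assumptions, not the ParallelExecution implementation. The toy model, the dict-of-arrays state, and the reshape of a flat batch to (n_pmap, n_vmap) are all hypothetical choices made here. n_pmap is taken from jax.local_device_count(), so the sketch also runs on a single CPU.

```python
import jax
import jax.numpy as jnp


def model(state, scale, offset=0.0):
    # Toy stand-in for a user model; `state` is a dict (a pytree) of scalars.
    return scale * state["param1"] * state["param2"] + offset


# Extra arguments, forwarded unchanged as model(state, *args, **kwargs).
args = (2.0,)
kwargs = {"offset": 0.5}


def fn(state):
    # Only `state` is mapped over; args/kwargs are closed over as constants.
    return model(state, *args, **kwargs)


n_pmap = jax.local_device_count()  # one shard per local device
n_vmap = 5                         # states vectorized within each device
n_total = n_pmap * n_vmap

# A flat batch of n_total states stored as a dict of 1-D arrays.
states = {
    "param1": jnp.linspace(0.0, 1.0, n_total),
    "param2": jnp.linspace(-1.0, 1.0, n_total),
}

# Reshape every leaf to (n_pmap, n_vmap): axis 0 -> devices, axis 1 -> vmap.
sharded = jax.tree_util.tree_map(lambda x: x.reshape(n_pmap, n_vmap), states)

out = jax.pmap(jax.vmap(fn))(sharded)  # shape (n_pmap, n_vmap)
flat = out.reshape(-1)                 # one result per state, original order
```

Nesting vmap inside pmap keeps each device's work as a single vectorized call. Closing over args and kwargs ensures that only the state is batched.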
Parameters
Name | Type | Description | Default |
---|---|---|---|
model | callable | Model function to execute. Should accept a state parameter and return simulation results. Signature: model(state, *args, **kwargs). | required |
statespace | AbstractSpace | Parameter space (DataSpace, UniformSpace, or GridSpace) defining the parameter combinations to execute across. | required |
*args | tuple | Positional arguments passed directly to the model function. | () |
n_vmap | int | Number of states to vectorize over using jax.vmap. Controls batch size for vectorized execution within each device. Default is 1. | 1 |
n_pmap | int | Number of devices to parallelize over using jax.pmap. Should typically match the number of available devices. Default is 1. | 1 |
**kwargs | dict | Keyword arguments passed directly to the model function. | {} |
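As a rough guide to how n_vmap and n_pmap interact, consider one plausible layout. This is an assumption and is not taken from the tvboptim source. Under it, n_vmap * n_pmap states are evaluated per call, and the last call is padded when the number of states is not a multiple of that. The hypothetical helper plan_batches below only does this arithmetic:

```python
import math


def plan_batches(n_states, n_vmap=1, n_pmap=1):
    """Hypothetical layout: return (n_rounds, n_padding) for n_states states."""
    per_round = n_vmap * n_pmap                   # states per pmap(vmap(...)) call
    n_rounds = math.ceil(n_states / per_round)    # calls needed to cover all states
    n_padding = n_rounds * per_round - n_states   # dummy states filling the last call
    return n_rounds, n_padding


# GridSpace(state, n=10) over two free parameters -> 10 * 10 = 100 states
print(plan_batches(100, n_vmap=5, n_pmap=4))  # (5, 0)
print(plan_batches(100, n_vmap=8, n_pmap=4))  # (4, 28)
```

Under this assumption, choosing n_vmap so that n_vmap * n_pmap divides the number of states avoids wasted work on padding.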
Examples
>>> import jax
>>> import jax.numpy as jnp
>>> from tvboptim.types.spaces import GridSpace
>>> from tvboptim.types.parameter import Parameter
>>> from tvboptim.execution import ParallelExecution
>>>
>>> # Define a simple model
>>> def simulate(state):
...     return state['param1'] * state['param2']
...
>>> # Create parameter space
>>> state = {
...     'param1': Parameter("param1", 0.0, low=0.0, high=1.0, free=True),
...     'param2': Parameter("param2", 0.0, low=-1.0, high=1.0, free=True)
... }
>>> space = GridSpace(state, n=10)  # 100 parameter combinations
>>>
>>> # Set up parallel execution
>>> n_devices = jax.device_count()
>>> executor = ParallelExecution(
...     model=simulate,
...     statespace=space,
...     n_vmap=5,  # Vectorize over 5 states per device
...     n_pmap=n_devices  # Use all available devices
... )
>>>
>>> # Execute across all parameter combinations
>>> results = executor.run()
>>>
>>> # Access individual results
>>> first_result = results[0]
>>> all_results = list(results) # Convert to list
>>> subset_results = results[10:20] # Slice notation
Notes
For optimal performance:
- Set n_pmap to match the number of available devices. On CPU, force XLA to expose N host devices by setting os.environ['XLA_FLAGS'] = f'--xla_force_host_platform_device_count={N}' before importing JAX.
- Tune n_vmap based on memory constraints and model complexity
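A minimal script for the CPU case might look like the following. N = 4 is an arbitrary choice. The script appends to any XLA_FLAGS already set. It must run in a fresh process before JAX initializes its backends; otherwise the flag has no effect.

```python
import os

N = 4  # arbitrary number of virtual CPU devices to expose

# Append the flag to any existing XLA_FLAGS; this must happen before JAX
# initializes its backends, so set it before importing jax.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + f" --xla_force_host_platform_device_count={N}"
).strip()

import jax  # noqa: E402  (imported after setting XLA_FLAGS on purpose)

cpu_devices = jax.devices("cpu")
print(len(cpu_devices))
```

With these N devices available, n_pmap=N can be passed to ParallelExecution. When CPU is the only backend, jax.device_count() also returns N.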
Because JAX dispatches work asynchronously, the execution calls jax.block_until_ready() so that all computation has finished before results are returned. This also makes wall-clock timings of run() reflect the actual compute time.
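The effect of asynchronous dispatch on timing can be seen with plain JAX, independent of ParallelExecution. This sketch uses an arbitrary jitted function, f, chosen only for illustration. It times one call after a warm-up call, so compilation is excluded.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    # Arbitrary elementwise work; sin^2 + cos^2 should be ~1 everywhere.
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2


x = jnp.linspace(0.0, 10.0, 1_000_000)

jax.block_until_ready(f(x))  # warm-up: triggers compilation and waits for it

t0 = time.perf_counter()
y = jax.block_until_ready(f(x))  # returns only once the result is computed
elapsed = time.perf_counter() - t0
print(f"{elapsed * 1e3:.2f} ms")
```

Dropping the block_until_ready call from the timed line would typically measure only the time to dispatch the work, not to finish it.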
Methods
Name | Description |
---|---|
run | Execute the model across all parameter combinations in parallel. |
run
execution.ParallelExecution.run()
Execute the model across all parameter combinations in parallel.