execution.ParallelExecution

execution.ParallelExecution(
    model,
    statespace,
    *args,
    n_vmap=1,
    n_pmap=1,
    **kwargs,
)

Efficient parallel execution of models across parameter spaces using JAX.

ParallelExecution orchestrates the parallel computation of a model function across all parameter combinations in a given state space. It leverages JAX’s pmap (for multi-device parallelism) and vmap (for vectorization).
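
A minimal sketch of the underlying pattern (for illustration only; an assumed structure, not the library's internal code): the model is vectorized over a batch of n_vmap states with jax.vmap, and the vectorized function is distributed over n_pmap devices with jax.pmap, so each step evaluates n_pmap * n_vmap parameter combinations.

>>> import jax
>>> import jax.numpy as jnp
>>> def model(state):
...     return state['param1'] * state['param2']
>>> n_pmap, n_vmap = 1, 5  # one device, five vectorized states per device
>>> states = {
...     'param1': jnp.linspace(0.0, 1.0, n_pmap * n_vmap).reshape(n_pmap, n_vmap),
...     'param2': jnp.ones((n_pmap, n_vmap)),
... }
>>> run_batch = jax.pmap(jax.vmap(model))  # pmap over devices, vmap within each device
>>> out = run_batch(states)                # result has shape (n_pmap, n_vmap)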

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| model | callable | Model function to execute. Should accept a state parameter and return simulation results. Signature: model(state, *args, **kwargs). | required |
| statespace | AbstractSpace | Parameter space (DataSpace, UniformSpace, or GridSpace) defining the parameter combinations to execute across. | required |
| *args | tuple | Positional arguments passed directly to the model function. | () |
| n_vmap | int | Number of states to vectorize over using jax.vmap. Controls the batch size for vectorized execution within each device. | 1 |
| n_pmap | int | Number of devices to parallelize over using jax.pmap. Should typically match the number of available devices. | 1 |
| **kwargs | dict | Keyword arguments passed directly to the model function. | {} |
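
A quick sizing illustration (chunking behavior inferred from the parameter descriptions above, not documented internals): each parallel step evaluates n_pmap * n_vmap parameter combinations.

>>> import math
>>> n_combinations = 100        # e.g. a GridSpace over two free parameters with n=10
>>> n_pmap, n_vmap = 4, 5       # 4 devices, 5 vectorized states per device
>>> per_step = n_pmap * n_vmap  # 20 combinations evaluated per parallel step
>>> math.ceil(n_combinations / per_step)
5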

Examples

>>> import jax
>>> import jax.numpy as jnp
>>> from tvboptim.types.spaces import GridSpace
>>> from tvboptim.types.parameter import Parameter
>>> from tvboptim.execution import ParallelExecution  # import path inferred from the qualified name above
>>> 
>>> # Define a simple model
>>> def simulate(state):
...     return state['param1'] * state['param2']
>>> 
>>> # Create parameter space
>>> state = {
...     'param1': Parameter("param1", 0.0, low=0.0, high=1.0, free=True),
...     'param2': Parameter("param2", 0.0, low=-1.0, high=1.0, free=True)
... }
>>> space = GridSpace(state, n=10)  # 100 parameter combinations
>>> 
>>> # Set up parallel execution
>>> n_devices = jax.device_count()
>>> executor = ParallelExecution(
...     model=simulate,
...     statespace=space,
...     n_vmap=5,           # Vectorize over 5 states per device
...     n_pmap=n_devices    # Use all available devices
... )
>>> 
>>> # Execute across all parameter combinations
>>> results = executor.run()
>>> 
>>> # Access individual results
>>> first_result = results[0]
>>> all_results = list(results)  # Convert to list
>>> subset_results = results[10:20]  # Slice notation

Notes

For optimal performance:

  • Set n_pmap to match the number of available devices. On CPU, you can force XLA to expose N host devices by setting os.environ['XLA_FLAGS'] = f'--xla_force_host_platform_device_count={N}' before JAX is first imported (see the example after this list).
  • Tune n_vmap based on memory constraints and model complexity
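
For example, to expose 8 host (CPU) devices (the count of 8 is just an illustrative choice), the flag must be set before JAX is imported:

>>> import os
>>> os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'
>>> import jax  # import only after the flag is set
>>> n_devices = jax.device_count()  # now reports 8 host devices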

The execution uses jax.block_until_ready() to ensure all computation completes before returning results, providing accurate timing measurements.
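
Blocking matters because JAX dispatches work asynchronously; a timer stopped before synchronization measures only dispatch time. A minimal, self-contained illustration (not library code):

>>> import time
>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.ones((1000, 1000))
>>> start = time.perf_counter()
>>> y = jax.block_until_ready(x @ x)  # waits until the product is actually computed
>>> elapsed = time.perf_counter() - start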

Methods

| Name | Description |
|------|-------------|
| run | Execute the model across all parameter combinations in parallel. |

run

execution.ParallelExecution.run()

Execute the model across all parameter combinations in parallel.