spaces.UniformAxis

spaces.UniformAxis(low, high, n, shape=None)

Axis for uniform random sampling over parameter bounds.

Generates n random values drawn from a uniform distribution between the low and high bounds.

Parameters

Name Type Description Default
low float Lower bound for sampling. required
high float Upper bound for sampling. required
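n int Number of random samples to generate. required
shape tuple Optional extra dimensions. If specified, generated values are broadcast to shape (n,) + shape. None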

Raises

Name Type Description
ValueError If n <= 0 or low >= high.
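
The checks behind this error can be sketched as a small standalone function. This is an illustration only: check_uniform_axis_args is a hypothetical helper, not part of the library, and the error messages are assumptions. Only the two conditions (n <= 0, low >= high) come from the table above.

```python
def check_uniform_axis_args(low: float, high: float, n: int) -> None:
    """Hypothetical stand-in for the documented argument validation."""
    # Documented condition: n must be strictly positive.
    if n <= 0:
        raise ValueError(f"n must be positive, got {n}")
    # Documented condition: low must be strictly below high.
    if low >= high:
        raise ValueError(f"low must be less than high, got low={low}, high={high}")


# Valid arguments pass silently.
check_uniform_axis_args(0.0, 1.0, 3)

# Equal bounds are rejected, because the documented condition is low >= high.
try:
    check_uniform_axis_args(1.0, 1.0, 3)
except ValueError as err:
    print(f"rejected: {err}")
```

Note that low == high counts as an error under this condition, so a degenerate single-point range cannot be expressed with this axis.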

Examples

>>> import jax
>>> uniform = UniformAxis(0.0, 1.0, 3)
>>> values = uniform.generate_values(jax.random.key(42))
>>> print(values)  # Random values between 0 and 1

Attributes

Name Description
size Number of random samples.

Methods

Name Description
generate_values Generate uniformly distributed random values.

generate_values

spaces.UniformAxis.generate_values(key=None)

Generate uniformly distributed random values.

Parameters

Name Type Description Default
key jax.random.PRNGKey Random key for sampling. If None, the default key(0) is used, so repeated calls without a key return the same values. None
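
The effect of the key parameter can be illustrated without the class itself. The sketch below assumes the sampling is equivalent to a call to jax.random.uniform. sample_uniform is a hypothetical stand-in, not the library's implementation, and its None handling simply mirrors the documented key(0) default.

```python
import jax
import jax.numpy as jnp


def sample_uniform(low, high, n, key=None):
    """Hypothetical stand-in for generate_values (not the library's code)."""
    # Assumption: fall back to a fixed key, as the documented default suggests.
    if key is None:
        key = jax.random.key(0)
    return jax.random.uniform(key, shape=(n,), minval=low, maxval=high)


# The same key always produces the same samples.
a = sample_uniform(0.0, 1.0, 3, jax.random.key(42))
b = sample_uniform(0.0, 1.0, 3, jax.random.key(42))
print(bool(jnp.array_equal(a, b)))

# Splitting a key yields independent subkeys for separate draws.
k1, k2 = jax.random.split(jax.random.key(42))
x = sample_uniform(0.0, 1.0, 3, k1)
y = sample_uniform(0.0, 1.0, 3, k2)
```

Because JAX random functions are deterministic given a key, omitting the key is convenient for quick tests, while passing distinct keys (for example via jax.random.split) is the usual way to obtain different samples on each call.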

Returns

Name Type Description
jnp.ndarray Array of n uniformly distributed random values. If shape is specified, the values are broadcast to shape (n,) + shape, so each sample is repeated unchanged across the additional dimensions.
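
The broadcasting described for the return value can be sketched with plain JAX operations. This is a minimal illustration under assumptions, not the library's code: the reshape-then-broadcast approach and the names n, shape, expanded and broadcast are chosen here for the example.

```python
import jax
import jax.numpy as jnp

n, shape = 3, (2, 4)

# Draw n samples, as generate_values is documented to do.
values = jax.random.uniform(jax.random.key(0), shape=(n,), minval=0.0, maxval=1.0)

# Add one singleton axis per extra dimension, then broadcast to (n,) + shape.
expanded = values.reshape((n,) + (1,) * len(shape))
broadcast = jnp.broadcast_to(expanded, (n,) + shape)

print(broadcast.shape)
# Every entry of broadcast[i] equals values[i].
print(bool(jnp.all(broadcast[1] == values[1])))
```

The leading axis still indexes the n samples; the trailing dimensions carry no new randomness.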