spaces.UniformAxis
spaces.UniformAxis(low, high, n, shape=None)

Axis for uniform random sampling over parameter bounds.

Generates n random values, each drawn independently from a uniform distribution between low and high.
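As a rough illustration of the sampling technique, the sketch below draws n uniform samples between two bounds with `jax.random.uniform`. The helper name `sample_uniform` is hypothetical. This is an assumption about how such an axis could be implemented, not the actual `UniformAxis` source.

```python
import jax
import jax.numpy as jnp


def sample_uniform(key: jax.Array, low: float, high: float, n: int) -> jax.Array:
    """Hypothetical helper: draw n samples uniformly from [low, high).

    Illustrates the technique only; it is not the UniformAxis implementation.
    """
    # jax.random.uniform rescales a standard uniform draw into [minval, maxval).
    return jax.random.uniform(key, shape=(n,), minval=low, maxval=high)


samples = sample_uniform(jax.random.key(42), 0.0, 1.0, 3)
print(samples.shape)  # (3,)
```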
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| low | float | Lower bound for sampling. | required |
| high | float | Upper bound for sampling. | required |
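| n | int | Number of random samples to generate. | required |
| shape | tuple of int or None | Extra dimensions appended to each sample (see generate_values). If None, samples have shape (n,). | None |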
Raises
| Name | Type | Description |
|---|---|---|
|  | ValueError | If n <= 0 or low >= high. |
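The documented error conditions can be expressed as a small validation step. The function `validate_uniform_args` below is a hypothetical sketch that mirrors the two conditions listed above; the real class may check them differently or word its messages otherwise.

```python
def validate_uniform_args(low: float, high: float, n: int) -> None:
    """Hypothetical check mirroring the documented ValueError conditions."""
    # At least one sample must be requested.
    if n <= 0:
        raise ValueError(f"n must be positive, got {n}")
    # The interval must be non-empty: low strictly below high.
    if low >= high:
        raise ValueError(f"low must be less than high, got low={low}, high={high}")


validate_uniform_args(0.0, 1.0, 3)  # valid arguments: no exception
```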
Examples
```python
>>> import jax
>>> uniform = UniformAxis(0.0, 1.0, 3)
>>> values = uniform.generate_values(jax.random.key(42))
>>> print(values)  # Random values between 0 and 1
```

Attributes
| Name | Description |
|---|---|
| size | Number of random samples. |
Methods
| Name | Description |
|---|---|
| generate_values | Generate uniformly distributed random values. |
generate_values
spaces.UniformAxis.generate_values(key=None)

Generate uniformly distributed random values.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| key | jax.random.PRNGKey | Random key for sampling. If None, a default key seeded with 0 (key(0)) is used, so repeated calls without a key return the same values. | None |
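The fallback-key behavior can be sketched as follows. `resolve_key` is a hypothetical helper, assuming the default is equivalent to `jax.random.key(0)`. It shows why omitting the key gives reproducible (identical) draws, while splitting an explicit key gives independent ones.

```python
import jax
import jax.numpy as jnp


def resolve_key(key=None):
    """Hypothetical helper: fall back to a fixed key seeded with 0."""
    return jax.random.key(0) if key is None else key


# Without an explicit key, both calls use the same fallback key,
# so they produce identical samples.
a = jax.random.uniform(resolve_key(), (3,))
b = jax.random.uniform(resolve_key(), (3,))

# Splitting a key yields independent subkeys, hence different samples.
k1, k2 = jax.random.split(jax.random.key(42))
c = jax.random.uniform(resolve_key(k1), (3,))
d = jax.random.uniform(resolve_key(k2), (3,))
```

To get fresh values on each call, pass a new subkey each time rather than relying on the default.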
Returns
| Name | Type | Description |
|---|---|---|
|  | jnp.ndarray | Array of uniformly distributed random values with shape (n,). If shape is specified, the values are broadcast to shape (n,) + shape, so each sample is repeated identically across the additional dimensions. |
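The broadcasting rule for a non-None shape can be illustrated with `jnp.broadcast_to`. The helper `broadcast_samples` is a hypothetical sketch of the documented behavior (output shape (n,) + shape, with each sample constant along the extra dimensions), not the library's code.

```python
import jax
import jax.numpy as jnp


def broadcast_samples(values: jax.Array, shape: tuple[int, ...]) -> jax.Array:
    """Hypothetical sketch: repeat each of the n samples across `shape`.

    The result has shape (n,) + shape, and result[i] is filled with values[i].
    """
    n = values.shape[0]
    # Add one singleton axis per extra dimension, then broadcast.
    expanded = values.reshape((n,) + (1,) * len(shape))
    return jnp.broadcast_to(expanded, (n,) + tuple(shape))


values = jax.random.uniform(jax.random.key(42), (3,))
grid = broadcast_samples(values, (2, 4))
print(grid.shape)  # (3, 2, 4)
```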