spaces.DataAxis

spaces.DataAxis(values)

Axis for sampling from predefined data values.

Uses a fixed set of values supplied by the user. This axis type is deterministic: it always returns the same predefined values. The first dimension of the data is always used as the axis dimension.
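The "first dimension" rule can be made concrete with a small JAX sketch. `sketch_axis_size` is a hypothetical helper, not part of the library. It only mirrors the documented behaviour under two assumptions: a scalar (0-d) input or a zero-length first dimension counts as "empty", and any trailing dimensions belong to each individual axis point. The real `DataAxis` may treat these edge cases differently.

```python
import jax.numpy as jnp


def sketch_axis_size(values):
    """Hypothetical helper mirroring the documented DataAxis rules.

    Assumption: a 0-d (scalar) input or a zero-length first dimension
    is treated as "empty" and rejected with ValueError.
    """
    arr = jnp.asarray(values)
    if arr.ndim == 0 or arr.shape[0] == 0:
        raise ValueError("values array is empty")
    # Only the first dimension is the axis dimension; trailing
    # dimensions describe the shape of each individual axis point.
    return arr.shape[0]


# 1-D data: four scalar axis points.
print(sketch_axis_size([1.0, 2.5, 3.7, 4.2]))  # 4

# 2-D data of shape (3, 2): three axis points, each a length-2 vector.
print(sketch_axis_size(jnp.zeros((3, 2))))  # 3
```

Under this sketch, data of shape (N, D) yields an axis of N points, not N * D.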

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| values | array-like | Predefined values to sample from. | required |

Raises

| Type | Description |
|------|-------------|
| ValueError | If the values array is empty. |

Examples

>>> import jax.numpy as jnp
>>> # Assumes DataAxis has already been imported from the spaces module.
>>> data = DataAxis([1.0, 2.5, 3.7, 4.2])
>>> values = data.generate_values()
>>> print(values)  # 1.0, 2.5, 3.7, 4.2 as a jnp.ndarray
>>> # JAX arrays are accepted as well
>>> data_jax = DataAxis(jnp.linspace(0, 1, 5))
>>> print(data_jax.size)  # 5

Attributes

| Name | Description |
|------|-------------|
| size | Number of predefined values. |

Methods

| Name | Description |
|------|-------------|
| generate_values | Return the predefined data values. |

generate_values

spaces.DataAxis.generate_values(key=None)

Return the predefined data values.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| key | jax.random.PRNGKey | Random key. Ignored, because the data values are fixed. | None |

Returns

| Type | Description |
|------|-------------|
| jnp.ndarray | The predefined data values. |
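Because the key is ignored, the result should not depend on which PRNG key is passed, or whether one is passed at all. The sketch below illustrates that contract with a hypothetical stand-in, `sketch_generate_values`, which is not the library's implementation. It contrasts the fixed values with a genuinely random draw from `jax.random.uniform` using the same keys.

```python
import jax
import jax.numpy as jnp


def sketch_generate_values(values, key=None):
    """Hypothetical stand-in for DataAxis.generate_values.

    The key is accepted for interface compatibility but never used,
    so every call returns the same array.
    """
    del key  # deliberately unused: the data are fixed
    return jnp.asarray(values)


data = [1.0, 2.5, 3.7, 4.2]
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))

fixed_a = sketch_generate_values(data, key_a)
fixed_b = sketch_generate_values(data, key_b)
fixed_none = sketch_generate_values(data)

# A random axis, by contrast, draws different samples for different keys.
random_a = jax.random.uniform(key_a, (4,))
random_b = jax.random.uniform(key_b, (4,))

print(bool(jnp.array_equal(fixed_a, fixed_b)))     # True
print(bool(jnp.array_equal(fixed_a, fixed_none)))  # True
print(bool(jnp.array_equal(random_a, random_b)))
```

Accepting and discarding the key keeps the call signature uniform, so code that passes a key to every axis can treat a fixed-data axis and a random one the same way.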