parameter.BoundedParameter

parameter.BoundedParameter(value, low, high)

Parameter with automatic bounds enforcement.

The bounds are applied transparently: whenever the parameter is used as a JAX array (for example, in arithmetic), its value is clipped to `[low, high]`, so every computation sees a value that satisfies the constraints.
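A minimal sketch of the clip-on-use idea, assuming the bounds are enforced with `jnp.clip` (the docstring example says the parameter "clips"). The `clipped` property and the explicit operator overloads are hypothetical; the real class may hook into JAX through a different mechanism.

```python
import jax.numpy as jnp


class BoundedParameter:
    """Hypothetical sketch: stores the raw value and clips it on every use."""

    def __init__(self, value, low, high):
        self.value = jnp.asarray(value)  # raw, possibly out-of-bounds value
        self.low = low
        self.high = high

    @property
    def clipped(self):
        # Hypothetical helper: the bounded view used by all arithmetic below.
        return jnp.clip(self.value, self.low, self.high)

    # Arithmetic operators see the clipped value, never the raw one.
    def __add__(self, other):
        return self.clipped + other

    def __radd__(self, other):
        return other + self.clipped

    def __mul__(self, other):
        return self.clipped * other

    def __rmul__(self, other):
        return other * self.clipped
```

With `param = BoundedParameter(1.5, low=0.0, high=1.0)`, both `param + 0.1` and `0.1 + param` operate on `1.0`, not `1.5`; the result of the addition is not itself clipped.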

Parameters

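| Name  | Type                           | Description                                    | Default  |
|-------|--------------------------------|------------------------------------------------|----------|
| value | Union[float, int, jnp.ndarray] | The parameter value; converted to a JAX array. | required |
| low   | float                          | Lower bound for the parameter value.           | required |
| high  | float                          | Upper bound for the parameter value.           | required |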
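Because `value` may be a `jnp.ndarray` while `low` and `high` are scalars, the bounds would apply elementwise. This sketch assumes the clip is `jnp.clip` with scalar bounds broadcast over the array; it shows only the clipping step, not the class itself.

```python
import jax.numpy as jnp

# value may be a float, an int, or a jnp.ndarray; here it is an array.
value = jnp.asarray([-0.5, 0.3, 1.7])
low, high = 0.0, 1.0  # scalar bounds broadcast to every element

# Assumed behavior: the elementwise clip applied when the parameter is used.
bounded = jnp.clip(value, low, high)
print(bounded)
```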

Examples

>>> param = BoundedParameter(1.5, low=0.0, high=1.0)
>>> result = param + 0.1  # param is clipped to 1.0 before the addition
>>> print(result)  # 1.1 (clipped value 1.0 plus 0.1)

Methods

| Name           | Description                             |
|----------------|-----------------------------------------|
| tree_flatten   | Flatten for JAX pytree registration.    |
| tree_unflatten | Unflatten for JAX pytree registration.  |

tree_flatten

parameter.BoundedParameter.tree_flatten()

Flatten the parameter into `(children, aux_data)` for JAX pytree registration.

tree_unflatten

parameter.BoundedParameter.tree_unflatten(aux_data, children)

Rebuild a `BoundedParameter` from `aux_data` and `children` for JAX pytree registration.
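A sketch of how these two methods typically fit JAX's pytree protocol, assuming the value is the only leaf and the bounds are static auxiliary data; that split, and the `clipped` helper, are assumptions rather than the library's confirmed layout. Registering the class lets it pass through `jax.jit` and `jax.grad`.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class BoundedParameter:
    """Hypothetical pytree sketch: the value is a leaf, the bounds are static."""

    def __init__(self, value, low, high):
        self.value = jnp.asarray(value)
        self.low = low
        self.high = high

    def clipped(self):
        # Hypothetical helper: the bounded value used in computations.
        return jnp.clip(self.value, self.low, self.high)

    def tree_flatten(self):
        children = (self.value,)          # leaves traced by jit/grad/vmap
        aux_data = (self.low, self.high)  # static metadata, must be hashable
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Bypass __init__ so unflattening with non-array children
        # (JAX sometimes passes placeholder objects) skips jnp.asarray.
        obj = object.__new__(cls)
        (obj.value,) = children
        obj.low, obj.high = aux_data
        return obj


@jax.jit
def loss(p):
    return p.clipped() ** 2
```

`jax.grad(loss)` returns a `BoundedParameter` whose `value` holds the gradient: `1.0` for a value of `0.5`, and `0.0` for a value of `1.5`, because the clip saturates outside `[low, high]`.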