parameter.BoundedParameter
parameter.BoundedParameter(value, low, high)

Parameter with automatic bounds enforcement.
The bounds are applied transparently, by clipping to `[low, high]`, whenever the parameter is used as a JAX array, so every value read from the parameter satisfies the constraints.
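This page does not say how the bounds are applied "transparently". One plausible mechanism is to clip inside JAX's `__jax_array__` conversion hook and to forward Python arithmetic operators through the same clipped value, since `+` and `*` do not consult that hook. The following is a minimal sketch of that idea under those assumptions. The class `ClippedParam` and its attributes are hypothetical, not the library's `BoundedParameter` implementation, and the use of `jnp.clip` is inferred from the example's "clips to [0, 1]" comment.

```python
import jax.numpy as jnp


class ClippedParam:
    """Hypothetical stand-in for BoundedParameter (not the library's code).

    Stores the raw value and clips it to [low, high] every time it is read.
    """

    def __init__(self, value, low, high):
        self.value = jnp.asarray(value)  # raw, possibly out-of-bounds value
        self.low = low
        self.high = high

    def __jax_array__(self):
        # Conversion hook used by e.g. jnp.asarray(p) in JAX versions that
        # support this protocol; returns the bounded array.
        return jnp.clip(self.value, self.low, self.high)

    # Python operators do not go through __jax_array__, so forward them
    # explicitly to the clipped array.
    def __add__(self, other):
        return self.__jax_array__() + other

    def __radd__(self, other):
        return other + self.__jax_array__()

    def __mul__(self, other):
        return self.__jax_array__() * other

    __rmul__ = __mul__  # multiplication is commutative here


p = ClippedParam(1.5, low=0.0, high=1.0)
result = p + 0.1  # the stored 1.5 is clipped to 1.0 before adding 0.1
```

In this sketch the stored value stays at 1.5 and only reads are clipped; whether the real class also clips at construction time is not stated on this page.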
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| value | Union[float, int, jnp.ndarray] | The parameter value; converted to a JAX array. | required |
| low | float | Lower bound for the parameter value. | required |
| high | float | Upper bound for the parameter value. | required |
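The table types `low` and `high` as scalars, while `value` may be an array. Assuming the bounds are applied with `jnp.clip` (the exact operation is not documented here), scalar bounds broadcast over an array-valued parameter and each element is clipped independently. The input values below are purely illustrative.

```python
import jax.numpy as jnp

# Illustrative array-valued parameter with scalar bounds.
value = jnp.asarray([-0.5, 0.3, 1.5])  # mirrors "converted to a JAX array"
low, high = 0.0, 1.0

# Scalar bounds broadcast; each element is clipped on its own.
bounded = jnp.clip(value, low, high)
```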
Examples
>>> param = BoundedParameter(1.5, low=0.0, high=1.0)
>>> result = param + 0.1  # param is clipped to 1.0 before the addition
>>> print(result)  # 1.1

Methods
| Name | Description |
|---|---|
| tree_flatten | Flatten for JAX pytree registration. |
| tree_unflatten | Unflatten for JAX pytree registration. |
tree_flatten
parameter.BoundedParameter.tree_flatten()

Flatten for JAX pytree registration.
tree_unflatten
parameter.BoundedParameter.tree_unflatten(aux_data, children)

Unflatten for JAX pytree registration.
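These two methods follow JAX's custom pytree protocol: `tree_flatten` returns `(children, aux_data)` and the classmethod `tree_unflatten(aux_data, children)` rebuilds the object. The sketch below assumes, without confirmation from this page, that the value is the only child and that `low` and `high` are static auxiliary data. `BoundedLeaf` and `bounded_square` are hypothetical names used only for illustration.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class BoundedLeaf:
    """Hypothetical class illustrating the tree_flatten / tree_unflatten contract."""

    def __init__(self, value, low, high):
        # No array conversion here: JAX may rebuild the object with
        # placeholder leaves when it manipulates tree structures.
        self.value = value
        self.low = low
        self.high = high

    def tree_flatten(self):
        children = (self.value,)        # leaves that JAX traces and transforms
        aux_data = (self.low, self.high)  # static, hashable metadata
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        low, high = aux_data
        (value,) = children
        return cls(value, low, high)


@jax.jit
def bounded_square(p):
    # Inside jit, p.value is a tracer while p.low and p.high are plain floats.
    x = jnp.clip(p.value, p.low, p.high)
    return x * x


p = BoundedLeaf(jnp.asarray(1.5), 0.0, 1.0)
doubled = jax.tree_util.tree_map(lambda v: 2 * v, p)  # maps over the value only
```

Keeping the bounds in `aux_data` makes them part of the tree structure, so under `jax.jit` a change to `low` or `high` triggers recompilation while a change to `value` does not.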