parameter.SigmoidBoundedParameter

parameter.SigmoidBoundedParameter(value, low, high)

Parameter with sigmoid-based bounds enforcement.

Applies a sigmoid transformation to map an unconstrained real value onto the bounded interval [low, high]. Because the mapping is smooth, gradients stay nonzero and continuous, whereas hard clipping yields a zero gradient once a value falls outside the bounds.
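The mapping described above can be sketched in a few lines of JAX. This is only an illustration of the general technique, not the class's actual implementation: the helper names to_unconstrained and to_constrained are hypothetical, and the sketch assumes the stored quantity is the logit of the rescaled value.

```python
import jax
import jax.numpy as jnp


def to_unconstrained(value, low, high):
    """Inverse map (logit): value inside (low, high) -> real number."""
    p = (value - low) / (high - low)  # rescale to (0, 1)
    return jnp.log(p) - jnp.log1p(-p)


def to_constrained(raw, low, high):
    """Forward map: any real number -> value between low and high.

    Mathematically the result lies strictly inside the interval; in
    floating point it can round to a bound for very large |raw|.
    """
    return low + (high - low) * jax.nn.sigmoid(raw)


# Round trip: 0.7 -> unconstrained space -> back to the bounded interval.
raw = to_unconstrained(0.7, low=0.0, high=1.0)
value = to_constrained(raw, low=0.0, high=1.0)
```

An optimizer would update raw freely over the real line, and the bounded value is recovered with to_constrained whenever it is needed.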

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| value | Union[float, int, jnp.ndarray] | The initial parameter value (will be clipped to bounds). | required |
| low | float | Lower bound for the parameter value. | required |
| high | float | Upper bound for the parameter value. | required |

Examples

>>> param = SigmoidBoundedParameter(0.7, low=0.0, high=1.0)
>>> param.__jax_array__()  # Returns value in [0, 1]
>>> # Gradients flow smoothly through sigmoid transformation
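To make the gradient remark concrete, the following standalone sketch compares hard clipping with a sigmoid reparameterization. It does not use SigmoidBoundedParameter itself; loss_clipped, loss_sigmoid, and the quadratic loss are hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp

LOW, HIGH = 0.0, 1.0


def loss_clipped(x):
    # Hard clipping: the output is constant once x leaves [LOW, HIGH].
    return (jnp.clip(x, LOW, HIGH) - 0.5) ** 2


def loss_sigmoid(raw):
    # Sigmoid reparameterization: the output varies smoothly with raw.
    value = LOW + (HIGH - LOW) * jax.nn.sigmoid(raw)
    return (value - 0.5) ** 2


grad_clipped = jax.grad(loss_clipped)(1.5)  # x lies outside the bounds
grad_sigmoid = jax.grad(loss_sigmoid)(1.5)  # raw is an arbitrary real number
```

With clipping, the gradient is zero at x = 1.5, so a gradient-based optimizer gets no signal to move back toward the target. The sigmoid version still returns a nonzero gradient at the same input.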

Methods

| Name | Description |
| --- | --- |
| tree_flatten | Flatten for JAX pytree registration. |
| tree_unflatten | Unflatten for JAX pytree registration. |

tree_flatten

parameter.SigmoidBoundedParameter.tree_flatten()

Flatten the parameter into its children and auxiliary data, as required for JAX pytree registration.

tree_unflatten

parameter.SigmoidBoundedParameter.tree_unflatten(aux_data, children)

Rebuild the parameter from aux_data and children, as required for JAX pytree registration. This is the inverse of tree_flatten.
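For readers unfamiliar with the pytree protocol, here is a minimal sketch of how a class with this pair of methods is typically registered in JAX. BoundedSketch is a hypothetical stand-in, not the library's class. Treating the unconstrained value as the only child and the bounds as static auxiliary data is an assumption made for illustration.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class BoundedSketch:
    """Hypothetical stand-in; not the library's SigmoidBoundedParameter."""

    def __init__(self, raw, low, high):
        self.raw = raw    # unconstrained value, treated as a pytree leaf
        self.low = low    # static bound (assumed to be auxiliary data)
        self.high = high  # static bound (assumed to be auxiliary data)

    def value(self):
        return self.low + (self.high - self.low) * jax.nn.sigmoid(self.raw)

    def tree_flatten(self):
        children = (self.raw,)            # leaves that JAX transforms act on
        aux_data = (self.low, self.high)  # static, hashable metadata
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (raw,) = children
        low, high = aux_data
        return cls(raw, low, high)


param = BoundedSketch(jnp.asarray(0.0), low=0.0, high=1.0)

# Flatten and rebuild explicitly, as JAX does internally.
leaves, treedef = jax.tree_util.tree_flatten(param)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)

# Once registered, the object can be passed straight to jax.grad; the
# result is another BoundedSketch whose raw field holds the gradient.
grads = jax.grad(lambda p: p.value() ** 2)(param)
```

Keeping the bounds in aux_data means JAX transformations treat them as fixed structure and never differentiate or trace them, which is one common design choice for bounded parameters.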