parameter.SigmoidBoundedParameter
parameter.SigmoidBoundedParameter(value, low, high)
Parameter with sigmoid-based bounds enforcement.
Uses a sigmoid transformation to map an unconstrained real value onto the bounded interval [low, high]. Hard clipping has zero gradient wherever the value falls outside the bounds. The sigmoid is smooth and strictly increasing, so its gradient stays nonzero everywhere.
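A minimal plain-Python sketch of the idea follows. It assumes the common parameterization value = low + (high - low) * sigmoid(raw), where raw is the unconstrained value being optimized. The helper names (`to_bounded`, `to_bounded_grad`, `clip_grad`) are hypothetical and are not part of this class's API. The sketch compares the derivative of the sigmoid map with the derivative of hard clipping.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic sigmoid 1 / (1 + exp(-x))."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)  # avoids overflow of exp(-x) for large negative x
    return z / (1.0 + z)


def to_bounded(raw: float, low: float, high: float) -> float:
    """Map an unconstrained real value into the interval between low and high."""
    return low + (high - low) * sigmoid(raw)


def to_bounded_grad(raw: float, low: float, high: float) -> float:
    """Derivative of to_bounded w.r.t. raw: (high - low) * s * (1 - s)."""
    s = sigmoid(raw)
    return (high - low) * s * (1.0 - s)


def clip_grad(x: float, low: float, high: float) -> float:
    """Derivative of min(max(x, low), high) w.r.t. x: zero outside the bounds."""
    return 1.0 if low < x < high else 0.0


# x = -5.0 and x = 5.0 lie outside [0, 1]: clipping gives zero gradient there,
# while the sigmoid map still gives a (small) positive gradient.
for x in (-5.0, 0.5, 5.0):
    print(f"x={x:+.1f}  sigmoid-bounded={to_bounded(x, 0.0, 1.0):.4f}  "
          f"d/dx={to_bounded_grad(x, 0.0, 1.0):.4f}  clip d/dx={clip_grad(x, 0.0, 1.0)}")
```

Because sigmoid never reaches exactly 0 or 1, the bounds are approached only asymptotically. For very large |raw|, the output saturates in floating point and the gradient shrinks toward zero. Gradients are therefore smooth but can still become small far from the center of the interval.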
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| value | Union[float, int, jnp.ndarray] | The initial parameter value (will be clipped to bounds). | required |
| low | float | Lower bound for the parameter value. | required |
| high | float | Upper bound for the parameter value. | required |
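The `value` row above says the initial value is clipped to the bounds. One plausible next step is to invert the sigmoid map (the logit) so that the stored unconstrained value reproduces the clipped value. This is sketched here as an assumption, not as this class's confirmed behavior. The sketch handles scalar floats only, and `init_raw`, `bounded_value`, and the `eps` margin are hypothetical.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic sigmoid."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def logit(p: float) -> float:
    """Inverse of sigmoid on the open interval (0, 1)."""
    return math.log(p) - math.log1p(-p)


def init_raw(value: float, low: float, high: float, eps: float = 1e-6) -> float:
    """Clip `value` to [low, high], then invert the sigmoid map.

    `eps` is a hypothetical safety margin: logit(0) and logit(1) are infinite,
    so the normalized value is kept inside [eps, 1 - eps] before inverting.
    """
    if not low < high:
        raise ValueError("low must be strictly less than high")
    clipped = min(max(float(value), low), high)
    p = (clipped - low) / (high - low)   # normalize to [0, 1]
    p = min(max(p, eps), 1.0 - eps)      # stay strictly inside (0, 1)
    return logit(p)


def bounded_value(raw: float, low: float, high: float) -> float:
    """Forward map: unconstrained raw -> value between low and high."""
    return low + (high - low) * sigmoid(raw)


raw = init_raw(0.7, 0.0, 1.0)
print(bounded_value(raw, 0.0, 1.0))   # recovers 0.7 up to rounding

raw = init_raw(5.0, 0.0, 1.0)         # out of range: clipped to 1.0 first
print(bounded_value(raw, 0.0, 1.0))   # just below 1.0 because of eps
```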
Examples
>>> param = SigmoidBoundedParameter(0.7, low=0.0, high=1.0)
>>> param.__jax_array__() # Returns value in [0, 1]
>>> # Gradients flow smoothly through the sigmoid transformation
Methods
| Name | Description |
|---|---|
| tree_flatten | Flatten for JAX pytree registration. |
| tree_unflatten | Unflatten for JAX pytree registration. |
tree_flatten
parameter.SigmoidBoundedParameter.tree_flatten()
Flatten for JAX pytree registration.
tree_unflatten
parameter.SigmoidBoundedParameter.tree_unflatten(aux_data, children)
Unflatten for JAX pytree registration.
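Below is a stdlib-only sketch of the contract these two methods satisfy. It assumes the class stores the unconstrained (pre-sigmoid) value as its single leaf and keeps `low`/`high` as static auxiliary data. `BoundedParamSketch` and its attributes are hypothetical and do not reflect this class's actual internals. In real code the class would be decorated with `jax.tree_util.register_pytree_node_class`, which calls `tree_flatten()` and the classmethod `tree_unflatten(aux_data, children)`.

```python
import math


class BoundedParamSketch:
    """Hypothetical stand-in illustrating the tree_flatten/tree_unflatten contract."""

    def __init__(self, value: float, low: float, high: float):
        # User-facing constructor: clip, normalize, then invert the sigmoid.
        p = (min(max(value, low), high) - low) / (high - low)
        p = min(max(p, 1e-6), 1.0 - 1e-6)
        self.raw = math.log(p) - math.log1p(-p)  # unconstrained leaf
        self.low = low                           # static metadata
        self.high = high

    def tree_flatten(self):
        # children: dynamic leaves JAX traces; aux_data: hashable static data.
        return (self.raw,), (self.low, self.high)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Skip __init__: children already hold the raw value, and __init__
        # would clip and re-transform it as if it were a user-facing value.
        obj = object.__new__(cls)
        (obj.raw,) = children
        obj.low, obj.high = aux_data
        return obj

    def value(self) -> float:
        r = self.raw
        s = 1.0 / (1.0 + math.exp(-r)) if r >= 0 else math.exp(r) / (1.0 + math.exp(r))
        return self.low + (self.high - self.low) * s


p = BoundedParamSketch(0.7, 0.0, 1.0)
children, aux = p.tree_flatten()
q = BoundedParamSketch.tree_unflatten(aux, children)
print(q.value())  # same bounded value as p.value()
```

Bypassing `__init__` in `tree_unflatten` also matters because JAX may pass placeholder objects or tracers as children, and validation or arithmetic in a constructor can fail on those.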