parameter.TransformedParameter
`parameter.TransformedParameter(value, forward_transform, inverse_transform)`
Parameter with custom forward and inverse transformations.
This enables arbitrary parameter transformations by applying:
- inverse transform: constrained → unconstrained (once, at initialization)
- forward transform: unconstrained → constrained (whenever the parameter is used as a JAX array)
The parameter is stored internally in unconstrained space, so an optimizer can update it freely without leaving the constrained domain; whenever it is used, it presents the constrained value.
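As an illustration of why unconstrained storage helps, the sketch below uses plain JAX (not the `TransformedParameter` class) to optimize a positive scale by taking gradient steps on its logarithm and mapping back with `exp`. The target `0.5`, the squared-error loss, the step size, and the iteration count are illustrative assumptions; the point is that every iterate stays positive without any clipping or projection.

```python
import jax
import jax.numpy as jnp

TARGET = 0.5         # illustrative target for a positive "scale" parameter
LEARNING_RATE = 0.1  # illustrative step size


def loss(raw_scale):
    # raw_scale is unconstrained; exp maps it onto (0, inf).
    scale = jnp.exp(raw_scale)
    return (scale - TARGET) ** 2


grad_fn = jax.jit(jax.grad(loss))

# Initialize from the constrained value 2.0 via the inverse transform (log).
raw = jnp.log(2.0)
history = []
for _ in range(500):
    raw = raw - LEARNING_RATE * grad_fn(raw)  # unconstrained step, no clipping
    history.append(float(jnp.exp(raw)))       # constrained value after the step

scale = jnp.exp(raw)  # forward transform back to constrained space
```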
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| value | Union[float, int, jnp.ndarray] | The initial constrained parameter value. | required |
| forward_transform | callable | Function to transform from unconstrained to constrained space. Applied every time the parameter is used as a JAX array. | required |
| inverse_transform | callable | Function to transform from constrained to unconstrained space. Applied once at initialization. | required |
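Because `inverse_transform` runs once and `forward_transform` runs on every use, the presented value only matches `value` if the two functions are mutual inverses on the constrained domain. The hypothetical helper `check_round_trip` below (not part of this API) is one way to sanity-check a pair before constructing the parameter; the tolerance is an assumption.

```python
import jax
import jax.numpy as jnp


def check_round_trip(value, forward_transform, inverse_transform, atol=1e-5):
    """Hypothetical helper: True if forward(inverse(value)) recovers value."""
    value = jnp.asarray(value, dtype=jnp.float32)
    recovered = forward_transform(inverse_transform(value))
    return bool(jnp.allclose(recovered, value, atol=atol))


# Matching pair: exp / log on positive values.
ok_log = check_round_trip(jnp.array([0.1, 2.0, 10.0]), jnp.exp, jnp.log)

# Matching pair: sigmoid / logit on values in (0, 1).
ok_sigmoid = check_round_trip(0.7, jax.nn.sigmoid, lambda x: jnp.log(x / (1 - x)))

# Mismatched pair: softplus(log(2.0)) = log(3.0), so the value is not recovered.
bad_pair = check_round_trip(2.0, jax.nn.softplus, jnp.log)
```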
Examples
>>> import jax
>>> import jax.numpy as jnp
>>> # Positive parameter via a log transform
>>> forward = lambda x: jnp.exp(x)  # unconstrained → constrained
>>> inverse = lambda x: jnp.log(x)  # constrained → unconstrained
>>> param = TransformedParameter(2.0, forward, inverse)
>>> param.__jax_array__()  # exp(log(2.0)): 2.0 up to floating-point rounding
>>> # Parameter bounded to (0, 1) via a sigmoid
>>> forward = lambda x: jax.nn.sigmoid(x)  # unconstrained → constrained
>>> inverse = lambda x: jnp.log(x / (1 - x))  # logit: constrained → unconstrained
>>> param = TransformedParameter(0.7, forward, inverse)
Methods
| Name | Description |
|---|---|
| tree_flatten | Flatten for JAX pytree registration. |
| tree_unflatten | Unflatten for JAX pytree registration. |
tree_flatten
`parameter.TransformedParameter.tree_flatten()`
Flatten for JAX pytree registration.
tree_unflatten
`parameter.TransformedParameter.tree_unflatten(aux_data, children)`
Unflatten for JAX pytree registration.
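These two methods follow JAX's `jax.tree_util.register_pytree_node_class` protocol: `tree_flatten` returns `(children, aux_data)` and the classmethod `tree_unflatten(aux_data, children)` rebuilds the object. The sketch below uses a hypothetical stand-in, `SketchTransformedParameter`, and assumes the unconstrained value is the only leaf while the two transforms travel as static auxiliary data; the actual class may split its state differently. It shows that `jax.grad` then differentiates with respect to the unconstrained value.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class SketchTransformedParameter:
    """Hypothetical stand-in for illustration, not the library class."""

    def __init__(self, value, forward_transform, inverse_transform):
        self.forward_transform = forward_transform
        self.inverse_transform = inverse_transform
        # Stored in unconstrained space; the inverse is applied once, here.
        self.unconstrained = inverse_transform(jnp.asarray(value, dtype=jnp.float32))

    def __jax_array__(self):
        # Forward transform applied on every use.
        return self.forward_transform(self.unconstrained)

    def tree_flatten(self):
        children = (self.unconstrained,)  # the leaf JAX traces and differentiates
        aux_data = (self.forward_transform, self.inverse_transform)  # static data
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        obj = object.__new__(cls)  # bypass __init__ so the inverse is not re-applied
        obj.forward_transform, obj.inverse_transform = aux_data
        (obj.unconstrained,) = children
        return obj


def loss(param):
    return (param.__jax_array__() - 0.5) ** 2


param = SketchTransformedParameter(2.0, jnp.exp, jnp.log)
grads = jax.grad(loss)(param)  # same structure; leaf holds d loss / d unconstrained
```

For the `exp`/`log` pair, the derivative with respect to the unconstrained value `u` is `2 * (exp(u) - 0.5) * exp(u)`, which is `6.0` at `u = log(2.0)`.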