parameter.TransformedParameter

parameter.TransformedParameter(value, forward_transform, inverse_transform)

Parameter with custom forward and inverse transformations.

This supports arbitrary parameter transformations, as long as the two functions are inverses of each other, by applying:

- inverse transform: constrained → unconstrained (once, at initialization)
- forward transform: unconstrained → constrained (every time the parameter is used as a JAX array)

The parameter is stored internally in unconstrained space, so an optimizer can update it freely without leaving the valid range, while any code that uses it always sees the constrained value.
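The storage scheme above can be illustrated with a minimal sketch. SketchTransformedParameter is a hypothetical stand-in written for this illustration, not the library's class; it assumes only what this page states: the inverse transform runs once in the constructor, and the forward transform runs each time __jax_array__ is called.

```python
import jax.numpy as jnp


class SketchTransformedParameter:
    """Hypothetical stand-in that mirrors the documented contract only."""

    def __init__(self, value, forward_transform, inverse_transform):
        self.forward_transform = forward_transform
        # Inverse transform, applied once: keep the value in unconstrained space.
        self.unconstrained = inverse_transform(jnp.asarray(value, dtype=jnp.float32))

    def __jax_array__(self):
        # Forward transform, applied on every use: return the constrained value.
        return self.forward_transform(self.unconstrained)


param = SketchTransformedParameter(2.0, jnp.exp, jnp.log)
print(param.unconstrained)    # log(2.0), roughly 0.693
print(param.__jax_array__())  # exp(log(2.0)), roughly 2.0
```

Only the unconstrained value is stored; the constrained value is recomputed on each use, so it can never drift out of the range the forward transform produces.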

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| value | Union[float, int, jnp.ndarray] | The initial constrained parameter value. | required |
| forward_transform | callable | Function to transform from unconstrained to constrained space. Applied every time the parameter is used as a JAX array. | required |
| inverse_transform | callable | Function to transform from constrained to unconstrained space. Applied once at initialization. | required |
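Because inverse_transform runs once and forward_transform runs on every use, the initial value only comes back unchanged if the two callables undo each other on the constrained domain. A quick sanity check on a few sample values is sketched below; transforms_round_trip is a hypothetical helper, not part of this API.

```python
import jax
import jax.numpy as jnp


def transforms_round_trip(forward, inverse, constrained_values, atol=1e-5):
    """Hypothetical helper: check that forward(inverse(v)) == v on sample values."""
    values = jnp.asarray(constrained_values, dtype=jnp.float32)
    return bool(jnp.allclose(forward(inverse(values)), values, atol=atol))


samples = [0.1, 0.5, 0.9]
logit = lambda x: jnp.log(x / (1 - x))

print(transforms_round_trip(jax.nn.sigmoid, logit, samples))    # True
print(transforms_round_trip(jax.nn.sigmoid, jnp.log, samples))  # False: not inverses
```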

Examples

>>> import jax
>>> import jax.numpy as jnp
>>> # Positive parameter (log/exp transform)
>>> forward = lambda x: jnp.exp(x)  # unconstrained → constrained
>>> inverse = lambda x: jnp.log(x)  # constrained → unconstrained
>>> param = TransformedParameter(2.0, forward, inverse)
>>> param.__jax_array__()  # exp(log(2.0)) ≈ 2.0
>>> # Sigmoid-bounded parameter in (0, 1)
>>> forward = lambda x: jax.nn.sigmoid(x)  # unconstrained → constrained
>>> inverse = lambda x: jnp.log(x / (1 - x))  # logit: constrained → unconstrained
>>> param = TransformedParameter(0.7, forward, inverse)
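The practical payoff of storing the unconstrained value is that gradient updates act on it directly, while the value the model sees stays in range. The sketch below writes out the sigmoid example by hand with jax.grad rather than using TransformedParameter itself; the loss, its target of 1.5, and the step size are purely illustrative.

```python
import jax
import jax.numpy as jnp

# Sigmoid-bounded parameter from the example above, written out by hand.
forward = jax.nn.sigmoid
inverse = lambda x: jnp.log(x / (1 - x))

u = inverse(jnp.float32(0.7))  # unconstrained value stored at initialization


def loss(u):
    # Illustrative objective that pulls the constrained value towards 1.5,
    # a target outside the (0, 1) range the sigmoid can produce.
    return (forward(u) - 1.5) ** 2


for _ in range(100):
    u = u - 0.5 * jax.grad(loss)(u)  # plain gradient descent on u

print(forward(u))  # moves towards 1 but stays strictly inside (0, 1)
```

The optimizer never needs to clip or project: however far u moves, the forward transform maps it back into the valid interval.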

Methods

| Name | Description |
|------|-------------|
| tree_flatten | Flatten for JAX pytree registration. |
| tree_unflatten | Unflatten for JAX pytree registration. |

tree_flatten

parameter.TransformedParameter.tree_flatten()

Flatten for JAX pytree registration.

tree_unflatten

parameter.TransformedParameter.tree_unflatten(aux_data, children)

Unflatten for JAX pytree registration.
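A sketch of how tree_flatten and tree_unflatten could fit together under jax.tree_util.register_pytree_node_class. The split is an assumption: the unconstrained array is taken to be the only child (so it is what jax.grad differentiates), and the two transforms are kept in aux_data. PytreeTransformedParameter is hypothetical and may not match the library's actual layout.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class PytreeTransformedParameter:
    """Hypothetical sketch: one leaf (unconstrained value), transforms as aux_data."""

    def __init__(self, value, forward_transform, inverse_transform):
        self.forward_transform = forward_transform
        self.inverse_transform = inverse_transform
        self.unconstrained = inverse_transform(jnp.asarray(value, dtype=jnp.float32))

    def __jax_array__(self):
        return self.forward_transform(self.unconstrained)

    def tree_flatten(self):
        children = (self.unconstrained,)  # leaves that JAX traces and differentiates
        aux_data = (self.forward_transform, self.inverse_transform)  # static metadata
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Bypass __init__ so the inverse transform is not applied a second
        # time to a value that is already unconstrained.
        obj = object.__new__(cls)
        obj.forward_transform, obj.inverse_transform = aux_data
        (obj.unconstrained,) = children
        return obj


logit = lambda x: jnp.log(x / (1 - x))
param = PytreeTransformedParameter(0.7, jax.nn.sigmoid, logit)

leaves, treedef = jax.tree_util.tree_flatten(param)
print(leaves)  # a single leaf: logit(0.7), roughly 0.847

# The gradient comes back as the same pytree type; its unconstrained slot
# holds d sigmoid(u) / du = 0.7 * (1 - 0.7) = 0.21.
grads = jax.grad(lambda p: p.__jax_array__())(param)
print(grads.unconstrained)
```

Keeping the callables in aux_data marks them as static; storing them as children would make them leaves, which jax.grad and jax.jit reject because they are not arrays.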