parameter.NormalizedParameter

parameter.NormalizedParameter(value)

Parameter that stores normalized values (an array of ones) internally but exposes scaled values (`scale * ones`) externally.

This lets an optimizer work in normalized coordinates, where gradient magnitudes are consistent across parameters with very different scales, while downstream computations still receive the properly scaled values.
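The effect of normalization on gradients follows from the chain rule. If the external value is x = scale ⊙ u, with u the internal normalized array, then ∂L/∂u = scale ⊙ ∂L/∂x, and an optimizer step Δu changes x by scale ⊙ Δu, so each step is measured relative to that entry's own scale. The sketch below checks this identity with plain `jax.grad` on a hypothetical quadratic loss (`loss_x` is an illustrative assumption); it does not use `NormalizedParameter` itself.

```python
import jax
import jax.numpy as jnp


def loss_x(x):
    # Hypothetical loss defined on the external (scaled) values.
    return jnp.sum((x - 1.0) ** 2)


scale = jnp.array([2.0, 4.0, 6.0])  # static scale, matching the example values
u = jnp.ones_like(scale)            # internal normalized storage starts at ones


def loss_u(v):
    # The same loss seen from the normalized coordinates: x = scale * v.
    return loss_x(scale * v)


grad_x = jax.grad(loss_x)(scale * u)  # gradient w.r.t. the external values
grad_u = jax.grad(loss_u)(u)          # gradient w.r.t. the normalized values

# Chain rule: each normalized gradient entry is the external one times its scale.
print(jnp.allclose(grad_u, scale * grad_x))  # True
```

Whether the normalized gradients end up more uniform in practice depends on how the loss responds to each parameter; the identity above is what the normalization guarantees.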

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `value` | `Union[float, int, jnp.ndarray]` | The original parameter value used to compute the static scale. | required |

Examples

```python
>>> import jax.numpy as jnp
>>> param = NormalizedParameter(jnp.array([2.0, 4.0, 6.0]))
>>> param.value  # Internal normalized storage (ones)
Array([1., 1., 1.], dtype=float32)
>>> param.__jax_array__()  # External scaled values (scale * ones)
Array([2., 4., 6.], dtype=float32)
>>> param.scale  # Static scale factor
Array([2., 4., 6.], dtype=float32)
```

Methods

| Name | Description |
|------|-------------|
| `tree_flatten` | Flatten for JAX pytree registration. |
| `tree_unflatten` | Unflatten for JAX pytree registration. |

tree_flatten

parameter.NormalizedParameter.tree_flatten()

Flatten the parameter into `(children, aux_data)` for JAX pytree registration.

tree_unflatten

parameter.NormalizedParameter.tree_unflatten(aux_data, children)

Rebuild a `NormalizedParameter` from `aux_data` and `children` for JAX pytree registration.
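JAX's custom-pytree protocol pairs these two methods: `tree_flatten` returns `(children, aux_data)`, and the classmethod `tree_unflatten(aux_data, children)` rebuilds the object. The sketch below uses a hypothetical stand-in, `ScaledParamSketch` (not the library's class), and assumes the normalized array is the only child while the static scale travels in `aux_data`. Under that assumption, tree utilities such as `jax.tree_util.tree_map` touch only the normalized values and the scale is preserved.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_flatten


@register_pytree_node_class
class ScaledParamSketch:
    """Hypothetical stand-in for NormalizedParameter; not the library's code."""

    def __init__(self, value, scale=None):
        if scale is None:
            # Assumption: the static scale is the original value itself,
            # and the internal storage starts as an array of ones.
            scale = jnp.asarray(value, dtype=jnp.float32)
            value = jnp.ones_like(scale)
        self.value = value  # normalized storage: the only pytree leaf
        self.scale = scale  # static scale: carried in aux_data

    def __jax_array__(self):
        # External view: scale * normalized values.
        return self.scale * self.value

    def tree_flatten(self):
        # Children are what JAX transforms and tree utilities operate on.
        return (self.value,), self.scale

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (value,) = children
        return cls(value, scale=aux_data)


param = ScaledParamSketch(jnp.array([2.0, 4.0, 6.0]))
leaves, treedef = tree_flatten(param)  # one leaf (the ones); scale is not a leaf

# tree_map edits only the normalized leaf and rebuilds via tree_unflatten.
updated = jax.tree_util.tree_map(lambda v: v - 0.5, param)
print(updated.value)            # every entry is now 0.5
print(updated.__jax_array__())  # scale * 0.5, i.e. 1, 2, 3
```

One design caveat: JAX expects `aux_data` to be hashable and comparable when a pytree passes through `jax.jit`, so keeping a JAX array there, as this sketch does, may fail under `jit`; a real implementation would need a hashable representation of the scale or a different split.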