parameter.NormalizedParameter
parameter.NormalizedParameter(value)
Parameter that stores normalized values (an array of ones) internally but presents scaled values (scale * ones) to the outside world.
This lets an optimizer work in normalized coordinates, where gradient magnitudes stay comparable across parameters of very different scales, while downstream computations still receive properly scaled values.
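To illustrate the gradient claim above, here is a small sketch in plain JAX that does not use NormalizedParameter itself. Writing x = scale * n, the chain rule gives dL/dn = scale * dL/dx. The loss is a hypothetical squared relative error, chosen as an assumption because it is scale-invariant, and the scale and target values are made up for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters spanning six orders of magnitude.
scale = jnp.array([1e-3, 1.0, 1e3])
# Hypothetical target: 10% above each parameter's value.
target = 1.1 * scale


def loss_physical(x):
    # Squared relative error (an illustrative, scale-invariant loss).
    return jnp.sum(((x - target) / target) ** 2)


def loss_normalized(n):
    # The optimizer works with n; the model only ever sees scale * n.
    return loss_physical(scale * n)


n0 = jnp.ones(3)                                # normalized storage: all ones
g_phys = jax.grad(loss_physical)(scale * n0)    # gradient w.r.t. physical values
g_norm = jax.grad(loss_normalized)(n0)          # gradient w.r.t. normalized values
```

With this loss, g_phys spans about six orders of magnitude (roughly -165 down to -1.7e-4 in absolute terms), whereas every entry of g_norm is about -0.165, i.e. scale * g_phys. For a loss that is not scale-invariant the normalized gradients will not be equal, only rescaled.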
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| value | Union[float, int, jnp.ndarray] | The original parameter value used to compute the static scale. | required |
Examples
>>> param = NormalizedParameter(jnp.array([2.0, 4.0, 6.0]))
>>> param.value # Internal normalized storage (ones)
Array([1., 1., 1.], dtype=float32)
>>> param.__jax_array__() # External scaled values (scale * ones)
Array([2., 4., 6.], dtype=float32)
>>> param.scale # Static scale factor
Array([2., 4., 6.], dtype=float32)
Methods
| Name | Description |
|---|---|
| tree_flatten | Flatten for JAX pytree registration. |
| tree_unflatten | Unflatten for JAX pytree registration. |
tree_flatten
parameter.NormalizedParameter.tree_flatten()
Flatten for JAX pytree registration.
tree_unflatten
parameter.NormalizedParameter.tree_unflatten(aux_data, children)
Unflatten for JAX pytree registration.
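As a rough sketch of how a class with this behavior could be registered as a pytree, consider the hypothetical NormalizedParameterSketch below. It is not the library's implementation. Two design choices are assumptions: only the normalized values are exposed as a leaf, so jax.grad differentiates with respect to them, and the static scale is kept as hashable Python data in aux_data so the object also works as a jax.jit argument.

```python
from typing import Union

import jax
import jax.numpy as jnp
import numpy as np


@jax.tree_util.register_pytree_node_class
class NormalizedParameterSketch:
    """Hypothetical stand-in for NormalizedParameter; not the library's code."""

    def __init__(self, value: Union[float, int, jnp.ndarray]):
        # Assumes a concrete (non-traced) value.
        arr = np.asarray(value, dtype=np.float32)
        # Static scale stored as (shape, flat tuple) so it is hashable aux data.
        self._aux = (arr.shape, tuple(arr.ravel().tolist()))
        # Normalized internal storage: all ones.
        self.value = jnp.ones(arr.shape, dtype=jnp.float32)

    @property
    def scale(self) -> jnp.ndarray:
        shape, flat = self._aux
        return jnp.asarray(flat, dtype=jnp.float32).reshape(shape)

    def __jax_array__(self) -> jnp.ndarray:
        # External view: scale * normalized values.
        return self.scale * self.value

    def tree_flatten(self):
        # Leaves: only the normalized values. Aux data: the static scale.
        return (self.value,), self._aux

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Bypass __init__: JAX may pass tracers or placeholders as children.
        obj = object.__new__(cls)
        obj._aux = aux_data
        (obj.value,) = children
        return obj


param = NormalizedParameterSketch(jnp.array([2.0, 4.0, 6.0]))
# Gradient of sum(x ** 2) with x = scale * n, taken w.r.t. n.
grads = jax.grad(lambda p: jnp.sum(p.__jax_array__() ** 2))(param)
```

Here grads is another NormalizedParameterSketch whose value field holds 2 * scale**2 * n, which is [8., 32., 72.] at n = 1.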