parameter.Parameter

parameter.Parameter(value)

A minimal JAX-native parameter with full arithmetic support.

When a Parameter is placed at a position in a state tree, that position takes part in optimization and is differentiated. Use partition_state to split a state into its parameters and its static parts, and combine_state to merge them back together.
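The split/combine idea can be illustrated with a minimal sketch. This is not the library's partition_state or combine_state, whose exact signatures are not documented here. It assumes a placeholder scheme (in the style of Equinox's partition/combine) where each output tree keeps the full structure and uses None for the holes. The names split_state and merge_state are hypothetical, and Parameter below is a simplified stand-in for the documented class.

```python
import jax
import jax.numpy as jnp


class Parameter:
    """Hypothetical stand-in for the documented Parameter: wraps a JAX array."""

    def __init__(self, value):
        self.value = jnp.asarray(value)


def _is_param(x):
    return isinstance(x, Parameter)


def split_state(state):
    """Hypothetical: split a state tree into (params, static) trees.

    Non-parameter leaves become None in `params`; parameter leaves
    become None in `static`, so both trees share the same structure.
    """
    params = jax.tree_util.tree_map(
        lambda x: x if _is_param(x) else None, state, is_leaf=_is_param
    )
    static = jax.tree_util.tree_map(
        lambda x: None if _is_param(x) else x, state, is_leaf=_is_param
    )
    return params, static


def merge_state(params, static):
    """Hypothetical inverse of split_state: fill each None hole from the other tree."""
    return jax.tree_util.tree_map(
        lambda p, s: s if p is None else p,
        params,
        static,
        # Treat None as a leaf so it is passed to the lambda instead of
        # being treated as an empty subtree.
        is_leaf=lambda x: x is None or _is_param(x),
    )


state = {"weight": Parameter([1.0, 2.0]), "activation": "relu", "layers": 3}
params, static = split_state(state)
restored = merge_state(params, static)
```

Only the params tree would then be handed to an optimizer or to jax.grad, while the static tree (strings, ints, configuration) is carried alongside unchanged.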

Parameters

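| Name  | Type                           | Description                                    | Default  |
|-------|--------------------------------|------------------------------------------------|----------|
| value | Union[float, int, jnp.ndarray] | The parameter value; converted to a JAX array. | required |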
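The conversion step can be sketched as follows. This assumes the constructor uses jnp.asarray, which is one plausible way to turn the documented input types into a JAX array; the class below is a hypothetical stand-in showing only that step.

```python
import jax.numpy as jnp


class Parameter:
    """Hypothetical stand-in showing only the constructor's conversion step."""

    def __init__(self, value):
        # Floats, ints, and existing arrays all become a JAX array.
        self.value = jnp.asarray(value)


p_float = Parameter(1.0)            # 0-d floating-point array
p_int = Parameter(3)                # 0-d integer array
p_vec = Parameter(jnp.array([1, 2, 3]))
```

With jnp.asarray, Python scalars keep their kind (float stays floating point, int stays integer), so the dtype of the stored value follows the input.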

Examples

>>> import jax
>>> import jax.numpy as jnp
>>> p1 = Parameter(1.0)
>>> p2 = Parameter(jnp.array([1, 2, 3]))
>>> result = p1 + p2  # broadcasts like ordinary JAX arithmetic
>>> grad_fn = jax.grad(lambda p: jnp.sum(p**2))
>>> gradients = grad_fn(p1)  # same pytree structure as p1
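The arithmetic in the examples above relies on operator overloading. A minimal sketch of one way to provide it, assuming each operator unwraps Parameter operands and forwards to the underlying JAX array, is shown below. The stand-in class is hypothetical, only a few operators are covered, and whether the real class returns a plain array or a new Parameter is not stated in this reference; this sketch returns plain arrays.

```python
import jax.numpy as jnp


class Parameter:
    """Hypothetical stand-in: forwards arithmetic operators to the wrapped array."""

    def __init__(self, value):
        self.value = jnp.asarray(value)

    @staticmethod
    def _unwrap(x):
        # Let Parameter operands mix freely with scalars and arrays.
        return x.value if isinstance(x, Parameter) else x

    def __add__(self, other):
        return self.value + self._unwrap(other)

    def __radd__(self, other):
        return self._unwrap(other) + self.value

    def __mul__(self, other):
        return self.value * self._unwrap(other)

    def __rmul__(self, other):
        return self._unwrap(other) * self.value

    def __pow__(self, exponent):
        return self.value ** self._unwrap(exponent)


p1 = Parameter(1.0)
p2 = Parameter(jnp.array([1, 2, 3]))
result = p1 + p2    # broadcasts a scalar against a length-3 array
doubled = 2 * p1    # int.__mul__ defers to Parameter.__rmul__
squared = p1 ** 2
```

Returning plain arrays keeps results usable by any jnp function; the trade-off is that the result is no longer marked as a parameter in a state tree.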

Attributes

| Name  | Description                       |
|-------|-----------------------------------|
| dtype | Data type of the parameter value. |
| ndim  | Number of dimensions.             |
| shape | Shape of the parameter value.     |
| size  | Total number of elements.         |
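These attributes mirror the array attributes of the same names. A minimal sketch, assuming each one is a read-only property that delegates to the wrapped JAX array, looks like this; the class is a hypothetical stand-in, not the documented implementation.

```python
import jax.numpy as jnp


class Parameter:
    """Hypothetical stand-in: array-like attributes are read from the wrapped value."""

    def __init__(self, value):
        self.value = jnp.asarray(value)

    @property
    def dtype(self):
        return self.value.dtype

    @property
    def ndim(self):
        return self.value.ndim

    @property
    def shape(self):
        return self.value.shape

    @property
    def size(self):
        return self.value.size


p = Parameter(jnp.zeros((2, 3)))
```

Delegating through properties keeps the attributes in sync with the value automatically, with no extra state to update.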

Methods

| Name           | Description                             |
|----------------|-----------------------------------------|
| tree_flatten   | Flatten for JAX pytree registration.    |
| tree_unflatten | Unflatten for JAX pytree registration.  |

tree_flatten

parameter.Parameter.tree_flatten()

Flatten the parameter into (children, aux_data), as required for JAX pytree registration.

tree_unflatten

parameter.Parameter.tree_unflatten(aux_data, children)

Rebuild a Parameter from aux_data and children, as required for JAX pytree registration.
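The two methods above form JAX's custom-pytree protocol. The sketch below shows one way to implement them, assuming the wrapped array is the only child and there is no auxiliary data. It uses jax.tree_util.register_pytree_node_class, and it rebuilds the object without calling __init__, following JAX's guidance that tree_unflatten may receive non-array placeholder leaves. The class is a hypothetical stand-in, not the documented source.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Parameter:
    """Hypothetical stand-in Parameter registered as a JAX pytree node."""

    def __init__(self, value):
        self.value = jnp.asarray(value)

    def tree_flatten(self):
        # Children are what JAX traces and differentiates;
        # aux_data is static metadata (none in this sketch).
        children = (self.value,)
        aux_data = None
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Skip __init__ so whatever leaf JAX passes in (tracers, gradients,
        # placeholder objects) is stored unchanged.
        obj = object.__new__(cls)
        obj.value = children[0]
        return obj


p = Parameter(3.0)
leaves, treedef = jax.tree_util.tree_flatten(p)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)

# jax.grad returns a gradient with the same pytree structure as its input,
# so the gradient of sum(value**2) comes back wrapped in a Parameter.
grads = jax.grad(lambda q: jnp.sum(q.value ** 2))(p)
```

Because the gradient is itself a Parameter, optimizer updates can be applied leaf-by-leaf with jax.tree_util.tree_map over matching structures.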