parameter.Parameter
parameter.Parameter(value)
A minimal JAX-native parameter with full arithmetic support.
When a Parameter is placed at a position in a state tree, that position takes part in optimization and is differentiated. partition_state splits a state into its parameters and its static parts, and combine_state merges the two back together.
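The split can be pictured with `jax.tree_util`. What follows is a minimal sketch of the idea, not the library's implementation: `Param`, `is_param`, `split_params`, and `merge_params` are hypothetical names, and the real `partition_state` and `combine_state` may have different signatures and return types. The sketch flattens the state while treating each `Param` as a single leaf, puts parameters on one side and everything else on the other, and rebuilds the tree afterwards, so `jax.grad` only sees the parameter side.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Param:
    """Hypothetical stand-in for Parameter: one JAX array registered as a pytree."""

    def __init__(self, value):
        self.value = jnp.asarray(value)

    def tree_flatten(self):
        return (self.value,), None  # one child, no auxiliary data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        obj = object.__new__(cls)  # skip __init__ so tracers pass through untouched
        (obj.value,) = children
        return obj


def is_param(x):
    return isinstance(x, Param)


def split_params(state):
    """Hypothetical partition: Param leaves vs. everything else (None marks a gap)."""
    leaves, treedef = tree_util.tree_flatten(state, is_leaf=is_param)
    params = [x if is_param(x) else None for x in leaves]
    static = [None if is_param(x) else x for x in leaves]
    return params, (treedef, static)


def merge_params(params, static_part):
    """Hypothetical combine: refill each gap from the static side and rebuild."""
    treedef, static = static_part
    leaves = [s if p is None else p for p, s in zip(params, static)]
    return tree_util.tree_unflatten(treedef, leaves)


state = {"w": Param(2.0), "name": "layer", "scale": 3}
params, static_part = split_params(state)


def loss(params):
    s = merge_params(params, static_part)  # static values are closed over, not traced
    return s["scale"] * s["w"].value ** 2


# Gradients come back in the shape of `params`; merging restores the state layout.
grads = merge_params(jax.grad(loss)(params), static_part)
print(grads["w"].value)  # d/dw (3 * w**2) evaluated at w = 2.0
```

Marking gaps with `None` works because JAX treats `None` as an empty subtree: it contributes no leaves, so the string and integer entries never reach `jax.grad`.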
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| value | Union[float, int, jnp.ndarray] | The parameter value; converted to a JAX array. | required |
Examples
>>> import jax
>>> import jax.numpy as jnp
>>> p1 = Parameter(1.0)
>>> p2 = Parameter(jnp.array([1, 2, 3]))
>>> result = p1 + p2  # works with JAX arithmetic
>>> grad_fn = jax.grad(lambda p: jnp.sum(p**2))
>>> gradients = grad_fn(p1)  # gradient with respect to p1
Attributes
| Name | Description |
|---|---|
| dtype | Data type of the parameter value. |
| ndim | Number of dimensions. |
| shape | Shape of the parameter value. |
| size | Total number of elements. |
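These attributes mirror those of the wrapped array. One plausible way to get that behaviour, together with the arithmetic support described at the top of this page, is to forward both to the stored JAX array. `ArrayBackedParam` is a hypothetical class written for illustration; in particular, returning plain JAX arrays from the operators is an assumption, and the real `Parameter` may wrap its results differently.

```python
import jax.numpy as jnp


class ArrayBackedParam:
    """Hypothetical sketch: metadata and arithmetic forwarded to a wrapped array."""

    def __init__(self, value):
        self.value = jnp.asarray(value)  # scalars, lists and arrays become JAX arrays

    # Attribute delegation: each property reads straight off the wrapped array.
    @property
    def dtype(self):
        return self.value.dtype

    @property
    def ndim(self):
        return self.value.ndim

    @property
    def shape(self):
        return self.value.shape

    @property
    def size(self):
        return self.value.size

    # Arithmetic delegation (assumption: results are plain JAX arrays).
    @staticmethod
    def _unwrap(x):
        return x.value if isinstance(x, ArrayBackedParam) else x

    def __add__(self, other):
        return self.value + self._unwrap(other)

    def __radd__(self, other):
        return self._unwrap(other) + self.value

    def __mul__(self, other):
        return self.value * self._unwrap(other)

    def __rmul__(self, other):
        return self._unwrap(other) * self.value

    def __pow__(self, exponent):
        return self.value ** self._unwrap(exponent)


p1 = ArrayBackedParam(1.0)
p2 = ArrayBackedParam(jnp.array([[1, 2, 3], [4, 5, 6]]))
print(p2.shape, p2.ndim, p2.size, p2.dtype)
print(p1 + p2)  # the scalar broadcasts against the 2x3 array
```

Because every property reads from `self.value`, the metadata can never drift out of sync with the stored array.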
Methods
| Name | Description |
|---|---|
| tree_flatten | Flatten for JAX pytree registration. |
| tree_unflatten | Unflatten for JAX pytree registration. |
tree_flatten
parameter.Parameter.tree_flatten()
Flatten for JAX pytree registration.
tree_unflatten
parameter.Parameter.tree_unflatten(aux_data, children)
Unflatten for JAX pytree registration.
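`tree_flatten` and `tree_unflatten` are the two halves of JAX's custom pytree protocol. The sketch below shows the usual form of that protocol for a single-array wrapper. `TreeParam` is a hypothetical stand-in, and it assumes (this page does not confirm it) that `Parameter` stores its array as one child with no auxiliary data.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class TreeParam:
    """Hypothetical stand-in showing the tree_flatten/tree_unflatten protocol."""

    def __init__(self, value):
        self.value = jnp.asarray(value)

    def tree_flatten(self):
        # children: the arrays JAX should trace and differentiate;
        # aux_data: static metadata (none is needed here).
        return (self.value,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Bypass __init__ so JAX can rebuild the node around tracers or
        # placeholder objects without triggering jnp.asarray.
        obj = object.__new__(cls)
        (obj.value,) = children
        return obj


p = TreeParam(jnp.array([1.0, 2.0, 3.0]))
print(tree_util.tree_leaves(p))  # the single wrapped array is the only leaf

# jax.grad flattens p, differentiates the leaf, and unflattens the result,
# so the gradient comes back as a TreeParam of the same shape.
grads = jax.grad(lambda q: jnp.sum(q.value ** 2))(p)
print(type(grads).__name__, grads.value)
```

Constructing the object in `tree_unflatten` without calling `__init__` follows the advice in the JAX pytree documentation: JAX may unflatten nodes around tracers or placeholder values, so unflattening should not validate or convert its children.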