parameter.MaskedParameter
parameter.MaskedParameter(value, mask)
Parameter that keeps the entries where the mask is False fixed at their initial values.
This allows selective optimization where only certain entries in an array are subject to optimization while others remain frozen. Useful for maintaining structural constraints like sparsity patterns, symmetries, or fixed values.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| value | Union[float, int, jnp.ndarray] | The initial parameter values. | required |
| mask | jnp.ndarray | Boolean mask where True indicates optimizable entries and False indicates frozen entries that maintain their initial values. | required |
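The mask semantics described in the table can be illustrated with a minimal sketch written in plain JAX. It does not show how MaskedParameter is implemented. The helper `apply_mask` and the quadratic `loss` are hypothetical names introduced only for this illustration.

```python
import jax
import jax.numpy as jnp


def apply_mask(trainable, initial, mask):
    """Hypothetical helper: take `trainable` where mask is True, `initial` elsewhere."""
    return jnp.where(mask, trainable, initial)


initial = jnp.array([1.0, 0.0, 3.0, 0.0, 5.0])
mask = initial != 0.0  # True marks optimizable entries


def loss(trainable):
    # Illustrative objective: pull every effective entry toward 2.0.
    effective = apply_mask(trainable, initial, mask)
    return jnp.sum((effective - 2.0) ** 2)


# Gradients are zero wherever the mask is False, so those entries never move.
grads = jax.grad(loss)(initial)

# One gradient-descent step; frozen entries are re-pinned to their initial values.
updated = apply_mask(initial - 0.1 * grads, initial, mask)
```

Because `jnp.where` selects `initial` at the frozen positions, no gradient reaches `trainable` there, which is one simple way to obtain the behaviour the mask describes.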
Examples
>>> import jax.numpy as jnp
>>> # Keep zero entries frozen, optimize non-zero entries
>>> value = jnp.array([1.0, 0.0, 3.0, 0.0, 5.0])
>>> mask = value != 0.0 # True for non-zero, False for zero
>>> param = MaskedParameter(value, mask)
>>> # Only positions [0, 2, 4] will be optimized

>>> # Upper triangular matrix optimization
>>> matrix = jnp.array([[1, 2], [0, 3]])
>>> mask = jnp.triu(jnp.ones_like(matrix)).astype(bool)
>>> param = MaskedParameter(matrix, mask)
Methods
| Name | Description |
|---|---|
| tree_flatten | Flatten for JAX pytree registration. |
| tree_unflatten | Unflatten for JAX pytree registration. |
tree_flatten
parameter.MaskedParameter.tree_flatten()
Flatten for JAX pytree registration.
tree_unflatten
parameter.MaskedParameter.tree_unflatten(aux_data, children)
Unflatten for JAX pytree registration.
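The two methods follow JAX's custom pytree protocol, which the sketch below illustrates with a stand-in class. `MaskedSketch` is hypothetical and is not the library's MaskedParameter. This page does not document how MaskedParameter splits its data between children and aux_data, so the sketch assumes the value is the only leaf and the mask is static auxiliary data. It covers only the flatten/unflatten round trip, not the freezing logic.

```python
import jax.numpy as jnp
import numpy as np
from jax import tree_util


@tree_util.register_pytree_node_class
class MaskedSketch:
    """Hypothetical stand-in used only to illustrate the pytree protocol."""

    def __init__(self, value, mask):
        self.value = value
        # Assumption: the mask is kept as a concrete NumPy array, not a traced leaf.
        self.mask = np.asarray(mask, dtype=bool)

    def tree_flatten(self):
        children = (self.value,)  # leaves visible to JAX transformations
        # Static data is stored in a hashable form (shape plus flat tuple of bools).
        aux_data = (self.mask.shape, tuple(self.mask.ravel().tolist()))
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        shape, flat_mask = aux_data
        (value,) = children
        mask = np.array(flat_mask, dtype=bool).reshape(shape)
        return cls(value, mask)


param = MaskedSketch(jnp.array([1.0, 0.0, 3.0]), [True, False, True])

# Round trip through the registered flatten/unflatten pair.
leaves, treedef = tree_util.tree_flatten(param)
rebuilt = tree_util.tree_unflatten(treedef, leaves)

# tree_map touches only the children; the mask travels along as aux_data.
doubled = tree_util.tree_map(lambda x: 2.0 * x, param)
```

Storing the mask in a hashable form is a design choice of this sketch: JAX compares and hashes aux_data when it caches compiled functions, so a raw array there can cause problems under `jax.jit`.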