parameter.MaskedParameter

parameter.MaskedParameter(value, mask)

Parameter whose entries are frozen at their initial values wherever mask is False.

This enables selective optimization: only some entries of an array are updated during optimization, while the rest stay frozen. It is useful for maintaining structural constraints such as sparsity patterns, symmetries, or fixed values.
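How MaskedParameter freezes entries internally is not documented here. One common JAX technique, assumed purely for illustration, is to recombine the trainable array with a `jax.lax.stop_gradient` copy of the initial values through `jnp.where`, so gradients reach only the entries where the mask is True. The helper `effective_value` is hypothetical, not part of the library.

```python
import jax
import jax.numpy as jnp


def effective_value(trainable, initial, mask):
    """Hypothetical helper: trainable entries where mask is True, frozen initial values elsewhere."""
    # stop_gradient turns the frozen branch into a constant for autodiff.
    return jnp.where(mask, trainable, jax.lax.stop_gradient(initial))


initial = jnp.array([1.0, 0.0, 3.0, 0.0, 5.0])
mask = initial != 0.0  # same sparsity mask as in the Examples section


def loss(trainable):
    # d(sum)/d(entry) is 1 for every entry that actually flows through.
    return jnp.sum(effective_value(trainable, initial, mask))


grads = jax.grad(loss)(initial)
# Gradient is 1.0 on the trainable entries and 0.0 on the frozen ones.
print(grads)
```

Because the frozen branch is a constant, any optimizer that follows these gradients leaves the masked-out entries unchanged.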

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| value | Union[float, int, jnp.ndarray] | The initial parameter values. | required |
| mask | jnp.ndarray | Boolean mask: True marks entries to optimize; False marks frozen entries that keep their initial values. | required |
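The table does not say how a mask is typically built. The snippet below shows a few ordinary `jax.numpy` ways to construct one for the constraints mentioned above; it does not call MaskedParameter, and it assumes (the table does not state this) that mask must have the same shape as value.

```python
import jax.numpy as jnp

value = jnp.array([[1.0, 0.0], [2.0, 4.0]])

# Sparsity pattern: optimize only entries that start non-zero.
sparsity_mask = value != 0.0

# Pin one specific entry (here [1, 0]) at its initial value; optimize the rest.
pin_mask = jnp.ones(value.shape, dtype=bool).at[1, 0].set(False)

# Combine constraints: an entry is trainable only if every mask allows it.
combined = sparsity_mask & pin_mask

# Assumption for this sketch: the mask has the same shape as the value.
assert combined.shape == value.shape
print(combined)
```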

Examples

>>> import jax.numpy as jnp
>>> # Keep zero entries frozen, optimize non-zero entries
>>> value = jnp.array([1.0, 0.0, 3.0, 0.0, 5.0])
>>> mask = value != 0.0  # True for non-zero, False for zero
>>> param = MaskedParameter(value, mask)
>>> # Only positions [0, 2, 4] will be optimized

>>> # Upper triangular matrix optimization (the lower-left entry stays frozen at 0.0)
>>> matrix = jnp.array([[1.0, 2.0], [0.0, 3.0]])  # floats, so the values are differentiable
>>> mask = jnp.triu(jnp.ones_like(matrix)).astype(bool)
>>> param = MaskedParameter(matrix, mask)
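To see what "only the upper triangle is optimized" means in practice, here is a plain gradient-descent loop on the same matrix and mask. It does not use MaskedParameter; it masks the gradient before each update, which is one simple way to get the documented behavior. The quadratic loss, the `target` array, and the learning rate are hypothetical choices for the demonstration.

```python
import jax
import jax.numpy as jnp

matrix = jnp.array([[1.0, 2.0], [0.0, 3.0]])
mask = jnp.triu(jnp.ones_like(matrix)).astype(bool)
target = jnp.full((2, 2), 10.0)  # hypothetical target pulling every entry toward 10


def loss(m):
    return jnp.sum((m - target) ** 2)


params = matrix
for _ in range(100):
    grads = jax.grad(loss)(params)
    # Zero the gradient on frozen entries so they never move.
    grads = jnp.where(mask, grads, 0.0)
    params = params - 0.1 * grads

# The three upper-triangular entries approach 10; the lower-left entry stays 0.0.
print(params)
```

Without the `jnp.where` line, the lower-left entry would also drift toward the target.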

Methods

| Name | Description |
|------|-------------|
| tree_flatten | Flatten for JAX pytree registration. |
| tree_unflatten | Unflatten for JAX pytree registration. |

tree_flatten

parameter.MaskedParameter.tree_flatten()

Flatten for JAX pytree registration.

tree_unflatten

parameter.MaskedParameter.tree_unflatten(aux_data, children)

Unflatten for JAX pytree registration.
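The two methods above are what JAX's pytree machinery calls to take the object apart into array leaves and rebuild it. The sketch below uses a hypothetical stand-in class, `SketchMaskedParam`; which fields the real MaskedParameter places in children versus aux_data is not documented here. In this sketch the boolean mask goes in aux_data (stored as a hashable tuple), because `jax.grad` rejects boolean leaves and aux_data must be hashable and comparable.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class SketchMaskedParam:
    """Hypothetical stand-in for illustration; not the library's MaskedParameter."""

    def __init__(self, value, mask):
        self.value = jnp.asarray(value, dtype=jnp.float32)
        mask = jnp.asarray(mask, dtype=bool)
        # Static, hashable copy of the mask: (shape, flat tuple of Python bools).
        self._mask_key = (mask.shape, tuple(mask.ravel().tolist()))

    @property
    def mask(self):
        shape, flat = self._mask_key
        return jnp.array(flat, dtype=bool).reshape(shape)

    def tree_flatten(self):
        # Children: array leaves that JAX transforms (differentiates, traces).
        # Aux data: static metadata that must be hashable and comparable.
        return (self.value,), self._mask_key

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Bypass __init__: JAX may pass tracers or placeholders as children.
        obj = object.__new__(cls)
        (obj.value,) = children
        obj._mask_key = aux_data
        return obj


p = SketchMaskedParam(jnp.array([1.0, 0.0, 3.0]), jnp.array([True, False, True]))

# Only the value array is a leaf; the mask travels inside the treedef.
leaves, treedef = tree_util.tree_flatten(p)
q = tree_util.tree_unflatten(treedef, leaves)


def loss(m):
    # Frozen entries pass through stop_gradient, so they get zero gradient.
    v = jnp.where(m.mask, m.value, jax.lax.stop_gradient(m.value))
    return jnp.sum(v * 2.0)


# jax.grad returns a SketchMaskedParam whose .value holds the gradient.
g = jax.grad(loss)(p)
print(len(leaves), g.value)
```

Keeping the mask static means changing it produces a different treedef, so a jitted function would retrace rather than silently reuse the old mask.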