types.stateutils

Functions

| Name | Description |
|------|-------------|
| collect_parameters | Extract values from Parameter objects in a state tree. |
| combine_state | Recombine optimized parameters with static values. |
| mark_parameters | Mark Parameter objects for partitioning. |
| partition_state | Separate Parameter objects from static values for optimization. |
| show_parameters | Show Parameter objects in the tree. |

collect_parameters

types.stateutils.collect_parameters(state)

Extract values from Parameter objects in a state tree.

This function traverses a JAX PyTree state and extracts the underlying values from Parameter objects while leaving other values unchanged.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| state | Any | JAX PyTree containing Parameter objects and other values. | required |

Returns

| Name | Type | Description |
|------|------|-------------|
| | Any | JAX PyTree with the same structure as the input, but with Parameter objects replaced by their underlying values. |

Examples

>>> import jax.numpy as jnp
>>> from tvboptim.types import Parameter
>>> from tvboptim.types.stateutils import collect_parameters
>>>
>>> # Create a state with Parameter objects
>>> state = {
...     'param1': Parameter(jnp.array(1.5)),
...     'param2': jnp.array(2.0),
...     'nested': {'param3': Parameter(jnp.array([1, 2, 3]))}
... }
>>>
>>> # Extract the raw values
>>> values = collect_parameters(state)
>>> print(values['param1'])
1.5
>>> print(values['param2'])
2.0

Notes

This function is useful when you need raw JAX arrays from a state tree for operations that do not require Parameter metadata. Under the new Parameter system, Parameter objects support the JAX array protocol directly, so this function may become less necessary.
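The traversal described above can be sketched with `jax.tree_util.tree_map`. This is a minimal illustration under stated assumptions, not tvboptim's implementation: `ToyParameter` and its `.value` attribute are hypothetical stand-ins for `Parameter`, and passing `is_leaf` assumes each wrapper should be handled as a single unit rather than descended into.

```python
import jax
import jax.numpy as jnp


class ToyParameter:
    """Hypothetical stand-in for tvboptim's Parameter: wraps one array in `.value`."""

    def __init__(self, value):
        self.value = value


def is_param(x):
    # Treat wrapper objects as leaves so tree_map hands them to us whole.
    return isinstance(x, ToyParameter)


def collect_values(state):
    # Hypothetical helper, not tvboptim's collect_parameters.
    # Replace each wrapper by its array; every other leaf passes through unchanged.
    return jax.tree_util.tree_map(
        lambda x: x.value if is_param(x) else x, state, is_leaf=is_param
    )


state = {
    "param1": ToyParameter(jnp.array(1.5)),
    "param2": jnp.array(2.0),
    "nested": {"param3": ToyParameter(jnp.array([1, 2, 3]))},
}
values = collect_values(state)
print(values["nested"]["param3"])  # [1 2 3]
```

The output tree keeps the nested dict layout of `state`; only the wrapped leaves change.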

combine_state

types.stateutils.combine_state(diff_state, static_state)

Recombine optimized parameters with static values.
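One common way to recombine a split state is the None-placeholder convention used by Equinox's `partition`/`combine`: each half holds `None` wherever the other half holds the value. The sketch below assumes that convention; whether `combine_state` follows it is not stated here, and `combine_trees` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def combine_trees(diff_state, static_state):
    # Hypothetical helper, not tvboptim's combine_state.
    # Each tree holds None wherever the other one holds the real value.
    # is_leaf makes None count as a leaf in diff_state, so both trees line up.
    return jax.tree_util.tree_map(
        lambda d, s: s if d is None else d,
        diff_state,
        static_state,
        is_leaf=lambda x: x is None,
    )


diff_state = {"coupling": jnp.array(0.5), "dt": None}
static_state = {"coupling": None, "dt": jnp.array(0.1)}
full = combine_trees(diff_state, static_state)
```

Because the optimizer only updates `diff_state`, the static half can be reused unchanged at every step and merged back in before the model is evaluated.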

mark_parameters

types.stateutils.mark_parameters(state)

Mark Parameter objects for partitioning.
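Marking leaves for partitioning is often done with a boolean mask tree that has the same structure as the state (Equinox calls this a filter spec). The sketch below assumes that approach; `mask_parameters` and `ToyParameter` are hypothetical and not part of tvboptim, and the real `mark_parameters` may mark leaves differently.

```python
import jax
import jax.numpy as jnp


class ToyParameter:
    """Hypothetical stand-in for tvboptim's Parameter."""

    def __init__(self, value):
        self.value = value


def mask_parameters(state):
    # Hypothetical helper, not tvboptim's mark_parameters.
    # Boolean mask with the same structure as `state`: True marks a trainable leaf.
    is_param = lambda x: isinstance(x, ToyParameter)
    return jax.tree_util.tree_map(is_param, state, is_leaf=is_param)


state = {"coupling": ToyParameter(jnp.array(0.5)), "dt": jnp.array(0.1)}
mask = mask_parameters(state)  # mask == {"coupling": True, "dt": False}
```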

partition_state

types.stateutils.partition_state(state)

Separate Parameter objects from static values for optimization.
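A minimal sketch of separating trainable leaves from static ones, assuming the None-placeholder convention (as in Equinox's `partition`); `split_parameters` and `ToyParameter` are hypothetical names, and the real `partition_state` may return a different structure.

```python
import jax
import jax.numpy as jnp


class ToyParameter:
    """Hypothetical stand-in for tvboptim's Parameter."""

    def __init__(self, value):
        self.value = value


def split_parameters(state):
    # Hypothetical helper, not tvboptim's partition_state.
    # Returns (diff_state, static_state) with None as a placeholder in each,
    # so both trees keep the layout of `state` and can be merged leaf by leaf.
    is_param = lambda x: isinstance(x, ToyParameter)
    diff_state = jax.tree_util.tree_map(
        lambda x: x if is_param(x) else None, state, is_leaf=is_param
    )
    static_state = jax.tree_util.tree_map(
        lambda x: None if is_param(x) else x, state, is_leaf=is_param
    )
    return diff_state, static_state


state = {"coupling": ToyParameter(jnp.array(0.5)), "dt": jnp.array(0.1)}
diff_state, static_state = split_parameters(state)
# None entries are empty subtrees, so tree_leaves skips them in each half.
```

JAX treats `None` as a node with no children, so each half exposes only its own leaves to transformations applied to it.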

show_parameters

types.stateutils.show_parameters(tree)

Show Parameter objects in the tree.
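Locating parameters in a tree can be done with `jax.tree_util.tree_flatten_with_path` and `jax.tree_util.keystr`. This sketch only collects path strings; what `show_parameters` actually displays is not documented here, and `parameter_paths` and `ToyParameter` are hypothetical.

```python
import jax
import jax.numpy as jnp


class ToyParameter:
    """Hypothetical stand-in for tvboptim's Parameter."""

    def __init__(self, value):
        self.value = value


def parameter_paths(tree):
    # Hypothetical helper, not tvboptim's show_parameters.
    # Flatten with key paths, treating wrappers as leaves, and keep only those.
    is_param = lambda x: isinstance(x, ToyParameter)
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree, is_leaf=is_param)
    return [jax.tree_util.keystr(path) for path, leaf in leaves_with_paths if is_param(leaf)]


state = {
    "param1": ToyParameter(jnp.array(1.5)),
    "param2": jnp.array(2.0),
    "nested": {"param3": ToyParameter(jnp.array([1, 2, 3]))},
}
for path in parameter_paths(state):
    print(path)
```

JAX visits dict keys in sorted order, so this prints `['nested']['param3']` followed by `['param1']`.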