types.stateutils
Functions
| Name | Description |
|---|---|
| collect_parameters | Extract values from Parameter objects in a state tree. |
| combine_state | Recombine optimized parameters with static values. |
| mark_parameters | Mark Parameter objects for partitioning. |
| partition_state | Separate Parameter objects from static values for optimization. |
| show_parameters | Show Parameter objects in the tree. |
collect_parameters
types.stateutils.collect_parameters(state)
Extract values from Parameter objects in a state tree.
This function traverses a JAX PyTree state and extracts the underlying values from Parameter objects while leaving other values unchanged.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| state | Any | JAX PyTree containing Parameter objects and other values. | required |
Returns
| Name | Type | Description |
|---|---|---|
| | Any | JAX PyTree with same structure as input, but Parameter objects replaced by their underlying values. |
Examples
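>>> from tvboptim.types import Parameter
>>> # Import path assumed from the module name documented on this page
>>> from tvboptim.types.stateutils import collect_parameters
>>> import jax.numpy as jnp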
>>>
>>> # Create state with Parameter objects
>>> state = {
... 'param1': Parameter(jnp.array(1.5)),
... 'param2': jnp.array(2.0),
... 'nested': {'param3': Parameter(jnp.array([1, 2, 3]))}
... }
>>>
>>> # Extract values
>>> values = collect_parameters(state)
>>> print(values['param1']) # jnp.array(1.5)
>>> print(values['param2']) # jnp.array(2.0)
Notes
This function is useful when you need to extract raw JAX arrays from a state tree for operations that don’t require Parameter metadata. With the new Parameter system, this function may become less necessary as Parameters support the JAX array protocol directly.
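The traversal described above can be sketched in plain Python. This is a minimal illustration, not the library's implementation: `Param` is a hypothetical stand-in for Parameter, `collect_values` is a hypothetical stand-in for collect_parameters, and nested dicts, lists, and tuples stand in for a JAX PyTree.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class Param:
    """Hypothetical stand-in for tvboptim's Parameter (illustration only)."""
    value: Any


def collect_values(tree: Any) -> Any:
    """Replace every Param in a nested container with its wrapped value."""
    if isinstance(tree, Param):
        return tree.value
    if isinstance(tree, dict):
        return {key: collect_values(child) for key, child in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(collect_values(child) for child in tree)
    return tree  # plain leaves pass through unchanged


state = {
    "param1": Param(1.5),
    "param2": 2.0,
    "nested": {"param3": Param([1, 2, 3])},
}
values = collect_values(state)
print(values)
```

The result keeps the structure of the input tree. Only the wrapped entries change, which matches the behavior documented for collect_parameters.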
combine_state
types.stateutils.combine_state(diff_state, static_state)
Recombine optimized parameters with static values.
mark_parameters
types.stateutils.mark_parameters(state)
Mark Parameter objects for partitioning.
partition_state
types.stateutils.partition_state(state)
Separate Parameter objects from static values for optimization.
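The partition and recombine idea can be sketched as a round trip in plain Python. This page does not document what partition_state returns, so the sketch assumes it yields two trees matching the `diff_state` and `static_state` arguments of combine_state. `Param`, `split_tree`, `merge_trees`, and the keys in `state` are all hypothetical names for illustration, and `None` is used here as the placeholder for the entries each tree leaves out.

```python
from dataclasses import dataclass
from typing import Any, Tuple


@dataclass
class Param:
    """Hypothetical stand-in for tvboptim's Parameter (illustration only)."""
    value: Any


def split_tree(tree: Any) -> Tuple[Any, Any]:
    """Return (diff, static); each mirrors `tree`, with None filling the gaps."""
    if isinstance(tree, dict):
        pairs = {key: split_tree(child) for key, child in tree.items()}
        diff = {key: pair[0] for key, pair in pairs.items()}
        static = {key: pair[1] for key, pair in pairs.items()}
        return diff, static
    if isinstance(tree, Param):
        return tree, None  # optimizable leaf goes to the diff tree
    return None, tree      # everything else goes to the static tree


def merge_trees(diff: Any, static: Any) -> Any:
    """Rebuild a single tree by taking each leaf from whichever side holds it."""
    if isinstance(diff, dict):
        return {key: merge_trees(diff[key], static[key]) for key in diff}
    return static if diff is None else diff


state = {"coupling": Param(0.5), "dt": 0.1, "model": {"gain": Param(2.0), "n_nodes": 4}}
diff, static = split_tree(state)

# Pretend an optimizer updated one parameter, then recombine.
diff["coupling"] = Param(0.75)
restored = merge_trees(diff, static)
print(restored)
```

Keeping both trees in the shape of the original means recombination is a leaf-by-leaf choice, so static values never need to pass through the optimizer.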
show_parameters
types.stateutils.show_parameters(tree)
Show Parameter objects in the tree.