optim.OptaxOptimizer
optim.OptaxOptimizer(loss, optimizer, callback=None, has_aux=False)

JAX-based parameter optimization using Optax optimizers with automatic differentiation.
OptaxOptimizer provides a high-level interface for optimizing model parameters using any Optax optimizer (Adam, SGD, RMSprop, etc.). It automatically handles parameter partitioning, gradient computation, and state management while supporting both forward-mode and reverse-mode automatic differentiation.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| loss | callable | Loss function to minimize. Should accept a state parameter and return a scalar loss value. Signature: loss(state) -> scalar or (scalar, aux_data) if has_aux=True. | required |
| optimizer | optax.GradientTransformation | Optax optimizer instance (e.g., optax.adam(0.001), optax.sgd(0.01)). Defines the optimization algorithm and hyperparameters. | required |
| callback | callable | Optional callback function called after each optimization step. Signature: callback(step, diff_state, static_state, fitting_data, aux_data, loss_value, grads) -> (stop_flag, new_diff_state, new_static_state). Default is None; see the callbacks module for ready-made callbacks. | None |
| has_aux | bool | Whether the loss function returns auxiliary data along with the loss value. If True, loss should return (loss_value, aux_data). Default is False. | False |
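The optimizer argument accepts any optax.GradientTransformation, including transformations composed with optax.chain. A minimal sketch follows; the schedule and clipping values are illustrative only, not recommendations.

>>> import optax
>>>
>>> # Decaying learning rate combined with global-norm gradient clipping
>>> schedule = optax.exponential_decay(
...     init_value=0.01, transition_steps=500, decay_rate=0.9)
>>> optimizer = optax.chain(
...     optax.clip_by_global_norm(1.0),      # guard against exploding gradients
...     optax.adam(learning_rate=schedule),
... )
>>> # Any such composed transformation can be passed as optimizer=...,
>>> # exactly like optax.adam(0.001) in the Examples below.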
Examples
>>> import jax
>>> import jax.numpy as jnp
>>> import optax
>>> from tvboptim.types.parameter import Parameter
>>> from tvboptim.optim import OptaxOptimizer  # import path assumed from the module layout
>>>
>>> # Define loss function
>>> def mse_loss(state):
... prediction = state['weight'] * state['input'] + state['bias']
... target = 2.5
... return jnp.mean((prediction - target) ** 2)
>>>
>>> # Define parameter state
>>> state = {
... 'weight': Parameter("weight", 1.0, free=True),
... 'bias': Parameter("bias", 0.0, free=True),
... 'input': 1.5 # Static parameter
... }
>>>
>>> # Create optimizer
>>> opt = OptaxOptimizer(
... loss=mse_loss,
... optimizer=optax.adam(learning_rate=0.01)
... )
>>>
>>> # Run optimization
>>> final_state, history = opt.run(state, max_steps=1000)
>>> print(f"Optimized weight: {final_state['weight']}")
>>>
>>> # With auxiliary data and callback
>>> def loss_with_aux(state):
... pred = state['weight'] * state['input'] + state['bias']
... loss = jnp.mean((pred - 2.5) ** 2)
... aux = {'prediction': pred, 'error': pred - 2.5}
... return loss, aux
>>>
>>> def monitor_callback(step, diff_state, static_state, fitting_data,
... aux_data, loss_value, grads):
... if step % 100 == 0:
... print(f"Step {step}: Loss = {loss_value}")
... # Early stopping condition
... stop = loss_value < 1e-6
... return stop, diff_state, static_state
>>>
>>> opt_aux = OptaxOptimizer(
... loss=loss_with_aux,
... optimizer=optax.adam(0.01),
... callback=monitor_callback,
... has_aux=True
... )
>>> final_state, history = opt_aux.run(state, max_steps=1000, mode="rev")

Notes
Parameter Partitioning:
The optimizer automatically partitions the state into:

- diff_state: Parameters marked as free=True (optimized)
- static_state: Parameters marked as free=False (held constant)
Only free parameters are optimized, while static parameters remain unchanged throughout the optimization process.
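A minimal sketch of how the free flag drives this split, reusing the Parameter type from the Examples above:

>>> from tvboptim.types.parameter import Parameter
>>>
>>> state = {
...     'weight': Parameter("weight", 1.0, free=True),   # -> diff_state (optimized)
...     'bias': Parameter("bias", 0.0, free=False),      # -> static_state (held fixed)
...     'input': 1.5,                                    # plain leaves also stay static
... }

Toggling free on a Parameter is therefore the only change needed to switch a quantity between being fitted and being held fixed.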
Differentiation Modes:
- "rev" (default): Reverse-mode AD; efficient when many parameters are free, since its cost scales with the number of outputs (here a single scalar loss).
- "fwd": Forward-mode AD; efficient when only a few parameters are free, or more generally when there are few inputs relative to outputs (a usage sketch follows this list).
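The only change between the two modes is the mode argument to run; a short sketch, with opt and state as constructed in the Examples above:

>>> # Reverse-mode (default): one backward pass per step, good for many free parameters
>>> final_state, history = opt.run(state, max_steps=500, mode="rev")
>>>
>>> # Forward-mode: cost grows with the number of free parameters,
>>> # so prefer it only when few parameters are free
>>> final_state, history = opt.run(state, max_steps=500, mode="fwd")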
Methods
| Name | Description |
|---|---|
| run | Execute parameter optimization for the specified number of steps. |
run
optim.OptaxOptimizer.run(state, max_steps=1, mode='rev')

Execute parameter optimization for the specified number of steps.
Performs gradient-based optimization of free parameters in the state using the configured Optax optimizer. Automatically handles parameter partitioning, gradient computation, and state updates.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| state | PyTree | Initial parameter state containing both free and static parameters. Free parameters (marked with free=True) will be optimized. | required |
| max_steps | int | Maximum number of optimization steps to perform. Default is 1. | 1 |
| mode | {"rev", "fwd"} | Automatic differentiation mode: "rev" for reverse-mode, "fwd" for forward-mode. Default is "rev". | "rev" |
Returns
A tuple (final_state, fitting_data):
| Name | Type | Description |
|---|---|---|
| final_state | PyTree | Optimized parameter state with updated free parameters and unchanged static parameters. |
| fitting_data | dict | Dictionary containing optimization history and metadata collected during the optimization process. |
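A short sketch of consuming the return value, with opt and state as constructed in the Examples above; the exact keys in fitting_data are not specified here and depend on the configured callbacks:

>>> final_state, fitting_data = opt.run(state, max_steps=200)
>>> print(final_state['weight'])        # optimized Parameter
>>> print(final_state['input'])         # static leaf, unchanged (1.5)
>>> print(sorted(fitting_data.keys()))  # available history/metadata entries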
Notes
Gradient Computation:
The method automatically selects appropriate gradient computation based on the mode parameter and loss function characteristics. Reverse-mode is typically preferred for parameter optimization scenarios.
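To illustrate the difference in plain JAX (this is not OptaxOptimizer's internal code, just the underlying AD mechanism): reverse mode differentiates a scalar loss in one backward pass over all inputs, while forward mode pushes one tangent through the computation per input.

>>> import jax
>>> import jax.numpy as jnp
>>>
>>> def loss(params):
...     return jnp.sum((params * 1.5 - 2.5) ** 2)
>>>
>>> params = jnp.array([1.0, 0.0])
>>> g_rev = jax.grad(loss)(params)     # reverse mode: one pass for all inputs
>>> g_fwd = jax.jacfwd(loss)(params)   # forward mode: one pass per input
>>> # Both yield the same gradient; they differ only in cost scaling
>>> print(bool(jnp.allclose(g_rev, g_fwd)))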