optim.OptaxOptimizer

optim.OptaxOptimizer(loss, optimizer, callback=None, has_aux=False)

JAX-based parameter optimization using Optax optimizers with automatic differentiation.

OptaxOptimizer provides a high-level interface for optimizing model parameters using any Optax optimizer (Adam, SGD, RMSprop, etc.). It automatically handles parameter partitioning, gradient computation, and state management while supporting both forward-mode and reverse-mode automatic differentiation.
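
A minimal end-to-end sketch (the import path tvboptim.optim is inferred from the qualified name above and may need adjusting):

>>> import jax.numpy as jnp
>>> import optax
>>> from tvboptim.optim import OptaxOptimizer
>>> from tvboptim.types.parameter import Parameter
>>> 
>>> # Quadratic loss with a single free parameter
>>> def loss(state):
...     return jnp.square(state['x'] - 3.0)
>>> 
>>> state = {'x': Parameter("x", 0.0, free=True)}
>>> opt = OptaxOptimizer(loss=loss, optimizer=optax.adam(0.1))
>>> final_state, fitting_data = opt.run(state, max_steps=200)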

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| loss | callable | Loss function to minimize. Accepts a state parameter and returns a scalar loss value. Signature: loss(state) -> scalar, or loss(state) -> (scalar, aux_data) if has_aux=True. | required |
| optimizer | optax.GradientTransformation | Optax optimizer instance (e.g., optax.adam(0.001), optax.sgd(0.01)). Defines the optimization algorithm and its hyperparameters. | required |
| callback | callable | Optional callback called after each optimization step. Signature: callback(step, diff_state, static_state, fitting_data, aux_data, loss_value, grads) -> (stop_flag, new_diff_state, new_static_state). See the callbacks module for many useful ready-made callbacks. | None |
| has_aux | bool | Whether the loss function returns auxiliary data along with the loss value. If True, the loss should return (loss_value, aux_data). | False |

Examples

>>> import jax
>>> import jax.numpy as jnp
>>> import optax
>>> from tvboptim.optim import OptaxOptimizer
>>> from tvboptim.types.parameter import Parameter
>>> 
>>> # Define loss function
>>> def mse_loss(state):
...     prediction = state['weight'] * state['input'] + state['bias']
...     target = 2.5
...     return jnp.mean((prediction - target) ** 2)
>>> 
>>> # Define parameter state
>>> state = {
...     'weight': Parameter("weight", 1.0, free=True),
...     'bias': Parameter("bias", 0.0, free=True),
...     'input': 1.5  # Static parameter
... }
>>> 
>>> # Create optimizer
>>> opt = OptaxOptimizer(
...     loss=mse_loss,
...     optimizer=optax.adam(learning_rate=0.01)
... )
>>> 
>>> # Run optimization
>>> final_state, history = opt.run(state, max_steps=1000)
>>> print(f"Optimized weight: {final_state['weight']}")
>>> 
>>> # With auxiliary data and callback
>>> def loss_with_aux(state):
...     pred = state['weight'] * state['input'] + state['bias']
...     loss = jnp.mean((pred - 2.5) ** 2)
...     aux = {'prediction': pred, 'error': pred - 2.5}
...     return loss, aux
>>> 
>>> def monitor_callback(step, diff_state, static_state, fitting_data, 
...                      aux_data, loss_value, grads):
...     if step % 100 == 0:
...         print(f"Step {step}: Loss = {loss_value}")
...     # Early stopping condition
...     stop = loss_value < 1e-6
...     return stop, diff_state, static_state
>>> 
>>> opt_aux = OptaxOptimizer(
...     loss=loss_with_aux,
...     optimizer=optax.adam(0.01),
...     callback=monitor_callback,
...     has_aux=True
... )
>>> final_state, history = opt_aux.run(state, max_steps=1000, mode="rev")

Notes

Parameter Partitioning:

The optimizer automatically partitions the state into:

  • diff_state: parameters marked as free=True (optimized)
  • static_state: parameters marked as free=False (held constant)

Only free parameters are optimized, while static parameters remain unchanged throughout the optimization process.
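
An illustrative sketch of the split (not tvboptim's actual implementation; it assumes Parameter exposes its free flag as an attribute):

>>> from tvboptim.types.parameter import Parameter
>>> 
>>> state = {
...     'weight': Parameter("weight", 1.0, free=True),
...     'bias': Parameter("bias", 0.0, free=False),
...     'input': 1.5,
... }
>>> # Entries flagged free=True conceptually form diff_state (optimized);
>>> # everything else stays in static_state (held constant).
>>> diff_state = {k: v for k, v in state.items()
...               if isinstance(v, Parameter) and v.free}  # .free attribute assumed
>>> static_state = {k: v for k, v in state.items() if k not in diff_state}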

Differentiation Modes:

  • "rev" (default): reverse-mode AD, efficient for many free parameters
  • "fwd": forward-mode AD, efficient for few free parameters or when derivatives of many outputs are needed
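
The trade-off mirrors plain JAX: reverse mode obtains the gradient of a scalar loss in a single backward pass, whereas forward mode propagates one tangent per input direction. A small, self-contained JAX illustration (independent of tvboptim):

>>> import jax
>>> import jax.numpy as jnp
>>> 
>>> def quadratic_loss(params):
...     return jnp.sum((params - 1.0) ** 2)
>>> 
>>> params = jnp.zeros(1000)
>>> g_rev = jax.grad(quadratic_loss)(params)    # one reverse pass, regardless of parameter count
>>> g_fwd = jax.jacfwd(quadratic_loss)(params)  # forward mode: one tangent per parameter
>>> bool(jnp.allclose(g_rev, g_fwd))
True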

Methods

| Name | Description |
|------|-------------|
| run | Execute parameter optimization for the specified number of steps. |

run

optim.OptaxOptimizer.run(state, max_steps=1, mode='rev')

Execute parameter optimization for the specified number of steps.

Performs gradient-based optimization of free parameters in the state using the configured Optax optimizer. Automatically handles parameter partitioning, gradient computation, and state updates.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| state | PyTree | Initial parameter state containing both free and static parameters. Free parameters (marked with free=True) are optimized. | required |
| max_steps | int | Maximum number of optimization steps to perform. | 1 |
| mode | {"rev", "fwd"} | Automatic differentiation mode: "rev" for reverse-mode, "fwd" for forward-mode. | "rev" |

Returns

| Name | Type | Description |
|------|------|-------------|
| final_state | PyTree | Optimized parameter state with updated free parameters and unchanged static parameters. |
| fitting_data | dict | Dictionary containing optimization history and metadata collected during the optimization process. |

The two values are returned together as a tuple (final_state, fitting_data).
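
Continuing the example above, the return values can be unpacked and inspected directly; fitting_data is a plain dict, so its recorded keys can be listed without assuming their names:

>>> final_state, fitting_data = opt.run(state, max_steps=1000)
>>> print(final_state['weight'])
>>> print(sorted(fitting_data.keys()))  # see what was recorded during the run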

Notes

Gradient Computation:

The method selects the gradient computation according to the mode argument and the loss function's configuration (such as has_aux). Reverse-mode is typically preferred when minimizing a scalar loss over many free parameters, while forward-mode can be competitive when only a few parameters are free.
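
Continuing the example above, the mode is chosen per run call (run and its mode argument are part of the documented API; the workload guidance in the comments is general AD advice):

>>> # Reverse mode (default): scalar loss, many free parameters
>>> final_state, fitting_data = opt.run(state, max_steps=500)
>>> 
>>> # Forward mode: can be competitive when only a few parameters are free
>>> final_state, fitting_data = opt.run(state, max_steps=500, mode="fwd")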