optim.OptaxOptimizer
optim.OptaxOptimizer(loss, optimizer, callback=None, has_aux=False)
JAX-based parameter optimization using Optax optimizers with automatic differentiation.
OptaxOptimizer provides a high-level interface for optimizing model parameters using any Optax optimizer (Adam, SGD, RMSprop, etc.). It automatically handles parameter partitioning, gradient computation, and state management while supporting both forward-mode and reverse-mode automatic differentiation.
Parameters
Name | Type | Description | Default |
---|---|---|---|
loss | callable | Loss function to minimize. Should accept a state parameter and return a scalar loss value. Signature: loss(state) -> scalar or (scalar, aux_data) if has_aux=True. | required |
optimizer | optax.GradientTransformation | Optax optimizer instance (e.g., optax.adam(0.001), optax.sgd(0.01)). Defines the optimization algorithm and hyperparameters. | required |
callback | callable | Optional callback function called after each optimization step. Signature: callback(step, diff_state, static_state, fitting_data, aux_data, loss_value, grads) -> (stop_flag, new_diff_state, new_static_state). Default is None; see the callbacks module for ready-made callbacks. | None |
has_aux | bool | Whether the loss function returns auxiliary data along with the loss value. If True, loss should return (loss_value, aux_data). Default is False. | False |
Examples
>>> import jax
>>> import jax.numpy as jnp
>>> import optax
>>> from tvboptim.types.parameter import Parameter
>>>
>>> # Define loss function
>>> def mse_loss(state):
...     prediction = state['weight'] * state['input'] + state['bias']
...     target = 2.5
...     return jnp.mean((prediction - target) ** 2)
>>>
>>> # Define parameter state
>>> state = {
...     'weight': Parameter("weight", 1.0, free=True),
...     'bias': Parameter("bias", 0.0, free=True),
...     'input': 1.5  # Static parameter
... }
>>>
>>> # Create optimizer
>>> opt = OptaxOptimizer(
...     loss=mse_loss,
...     optimizer=optax.adam(learning_rate=0.01)
... )
>>>
>>> # Run optimization
>>> final_state, history = opt.run(state, max_steps=1000)
>>> print(f"Optimized weight: {final_state['weight']}")
>>>
>>> # With auxiliary data and callback
>>> def loss_with_aux(state):
...     pred = state['weight'] * state['input'] + state['bias']
...     loss = jnp.mean((pred - 2.5) ** 2)
...     aux = {'prediction': pred, 'error': pred - 2.5}
...     return loss, aux
>>>
>>> def monitor_callback(step, diff_state, static_state, fitting_data,
...                      aux_data, loss_value, grads):
...     if step % 100 == 0:
...         print(f"Step {step}: Loss = {loss_value}")
...     # Early stopping condition
...     stop = loss_value < 1e-6
...     return stop, diff_state, static_state
>>>
>>> opt_aux = OptaxOptimizer(
...     loss=loss_with_aux,
...     optimizer=optax.adam(0.01),
...     callback=monitor_callback,
...     has_aux=True
... )
>>> final_state, history = opt_aux.run(state, max_steps=1000, mode="rev")
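A callback can also stop on the gradient norm rather than the loss value. The following is a minimal sketch, not part of the library, assuming grads is a PyTree of gradient arrays aligned with diff_state so that optax.global_norm applies:
>>> def grad_norm_callback(step, diff_state, static_state, fitting_data,
...                        aux_data, loss_value, grads):
...     # Assumption: grads is a PyTree of arrays, so optax.global_norm works.
...     gnorm = optax.global_norm(grads)
...     stop = gnorm < 1e-8  # stop once gradients are effectively zero
...     return stop, diff_state, static_state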
Notes
Parameter Partitioning:
The optimizer automatically partitions the state into:
- diff_state: parameters marked free=True (optimized)
- static_state: parameters marked free=False (held constant)
Only free parameters are optimized, while static parameters remain unchanged throughout the optimization process.
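The split can be pictured with the hypothetical sketch below; StubParam and split_state are stand-ins for intuition only, not tvboptim APIs:
>>> # Hypothetical illustration of the free/static partition.
>>> from dataclasses import dataclass
>>> @dataclass
... class StubParam:
...     name: str
...     value: float
...     free: bool = False
...
>>> def split_state(state):
...     diff = {k: v for k, v in state.items()
...             if isinstance(v, StubParam) and v.free}
...     static = {k: v for k, v in state.items() if k not in diff}
...     return diff, static
...
>>> demo = {'weight': StubParam("weight", 1.0, free=True), 'input': 1.5}
>>> diff_state, static_state = split_state(demo)
>>> sorted(diff_state), static_state
(['weight'], {'input': 1.5})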
Differentiation Modes:
- "rev" (default): Reverse-mode AD, efficient when many free parameters feed a scalar loss
- "fwd": Forward-mode AD, efficient when there are only a few free parameters (illustrated below)
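For a scalar loss, both modes produce the same gradient with different cost profiles. A small plain-JAX sketch (independent of this class) for intuition:
>>> theta = jnp.arange(4.0)
>>> toy_loss = lambda p: jnp.sum((p - 1.0) ** 2)
>>> g_rev = jax.grad(toy_loss)(theta)    # one backward pass regardless of size
>>> g_fwd = jax.jacfwd(toy_loss)(theta)  # one forward pass per input element
>>> bool(jnp.allclose(g_rev, g_fwd))
True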
Methods
Name | Description |
---|---|
run | Execute parameter optimization for the specified number of steps. |
run
optim.OptaxOptimizer.run(state, max_steps=1, mode='rev')
Execute parameter optimization for the specified number of steps.
Performs gradient-based optimization of free parameters in the state using the configured Optax optimizer. Automatically handles parameter partitioning, gradient computation, and state updates.
Parameters
Name | Type | Description | Default |
---|---|---|---|
state | PyTree | Initial parameter state containing both free and static parameters. Free parameters (marked with free=True) will be optimized. | required |
max_steps | int | Maximum number of optimization steps to perform. Default is 1. | 1 |
mode | {"rev", "fwd"} | Automatic differentiation mode: "rev" for reverse-mode (default), "fwd" for forward-mode. | "rev" |
Returns
Name | Type | Description |
---|---|---|
| tuple | A tuple (final_state, fitting_data): final_state (PyTree) is the optimized parameter state with updated free parameters and unchanged static parameters; fitting_data (dict) contains optimization history and metadata collected during the run. |
Notes
Gradient Computation:
The method automatically selects the gradient computation based on the mode argument and the characteristics of the loss function. Reverse-mode is typically preferred when many parameters are optimized against a scalar loss.
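With only two free parameters, as in the Examples above, forward mode is also a reasonable choice. A minimal sketch reusing the opt and state objects defined there:
>>> # Sketch: forward-mode run for a small number of free parameters.
>>> final_state, fitting_data = opt.run(state, max_steps=500, mode="fwd")
>>> print(f"Final weight: {final_state['weight']}")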