Loss Function Specification

Mathematical loss functions with symbolic equations, aggregation, and code generation

Overview

TVBO uses SymPy for symbolic mathematics, enabling:

  1. Mathematical notation in YAML specs
  2. LaTeX rendering for documentation
  3. Code generation to JAX/NumPy
  4. Direct execution via class methods
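
These capabilities rest on standard SymPy primitives. A minimal sketch of the underlying mechanism, using plain SymPy rather than the TVBO API:

```python
import sympy as sp

# Parse a string equation into a symbolic expression
expr = sp.sympify("1 - x**2")

# Render it as LaTeX for documentation
print(sp.latex(expr))  # 1 - x^{2}

# Generate an executable function (NumPy backend)
f = sp.lambdify(sp.Symbol("x"), expr, "numpy")
print(f(0.5))  # 0.75
```

TVBO layers YAML loading, argument metadata, and JAX code generation on top of this parse-render-compile pipeline.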

Loading Functions from YAML

Functions are defined in YAML and loaded using the extended Function and LossFunction classes:

from tvbo import Function, LossFunction

Example: Correlation Function

From JR_tvboptim.qmd:

# YAML specification for correlation function
correlation_yaml = """
name: correlation
description: "Pearson correlation coefficient"
equation:
  rhs: "Sum((x[i] - mean(x)) * (y[i] - mean(y)), (i, 0, n-1)) / sqrt(Sum((x[i] - mean(x))**2, (i, 0, n-1)) * Sum((y[i] - mean(y))**2, (i, 0, n-1)))"
  definition: "Pearson product-moment correlation coefficient"
arguments:
  - name: x
    description: "First array"
  - name: y
    description: "Second array"
"""

# Load as Function class
corr_func = Function.from_string(correlation_yaml)

# Access SymPy expression
expr = corr_func.sympy_expression
expr

\(\displaystyle \frac{\sum_{i=0}^{n - 1} \left(- \operatorname{mean}{\left(x \right)} + {x}_{i}\right) \left(- \operatorname{mean}{\left(y \right)} + {y}_{i}\right)}{\sqrt{\left(\sum_{i=0}^{n - 1} \left(- \operatorname{mean}{\left(x \right)} + {x}_{i}\right)^{2}\right) \sum_{i=0}^{n - 1} \left(- \operatorname{mean}{\left(y \right)} + {y}_{i}\right)^{2}}}\)
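
The specification itself is plain YAML, so it can also be inspected independently of the TVBO classes, e.g. with PyYAML (illustration only, abbreviated spec):

```python
import yaml

# Parse an abbreviated version of the spec above
spec = yaml.safe_load("""
name: correlation
description: "Pearson correlation coefficient"
arguments:
  - name: x
    description: "First array"
  - name: y
    description: "Second array"
""")
print(spec["name"])                                # correlation
print([arg["name"] for arg in spec["arguments"]])  # ['x', 'y']
```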

Generate Executable Code

# Generate JAX code using the render_code method
jax_code = corr_func.render_code(format="jax")
print(jax_code)
def correlation(x, y):
    """Pearson correlation coefficient"""
    return jnp.sum((-jnp.mean(x) + x)*(-jnp.mean(y) + y))/jnp.sqrt(jnp.sum((-jnp.mean(x) + x)**2)*jnp.sum((-jnp.mean(y) + y)**2))

Example: Loss Function with Aggregation

The LossFunction class extends Function with aggregation over a named dimension:

# YAML specification for loss function
loss_yaml = """
name: spectral_correlation
label: "Spectral Correlation Loss"
equation:
  rhs: "1 - correlation(simulated_psd, target_data)"
aggregate:
  over: node
  type: mean
description: >
  Pearson correlation between simulated and target power spectra.
  Loss computed per node, then averaged across all nodes.
"""

# Load as LossFunction class
loss_func = LossFunction.from_string(loss_yaml)

# View SymPy expression
loss_expr = loss_func.sympy_expression
loss_expr

\(\displaystyle 1 - \operatorname{correlation}{\left(simulated_{psd},target_{data} \right)}\)

Generate Loss Function Code with Aggregation

# Generate loss function code with aggregation
# inner_func_names tells the generator which functions are user-defined
loss_code = loss_func.render_code(inner_func_names=["correlation"])
print(loss_code)
def spectral_correlation(simulated_psd, target_data):
    """Pearson correlation between simulated and target power spectra. Loss computed per node, then averaged across all nodes.
"""
    def _per_element_loss(simulated_psd, target_data):
        return 1 - correlation(simulated_psd, target_data)

    # vmap over node dimension, then reduce with mean
    per_element_losses = jax.vmap(_per_element_loss)(simulated_psd, target_data)
    return jnp.mean(per_element_losses)
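
The vmap-then-reduce pattern in the generated code can be illustrated with a toy per-element loss (a hypothetical squared-error loss, not TVBO output):

```python
import jax
import jax.numpy as jnp

# Hypothetical per-node loss: mean squared error per row
def _per_element_loss(sim, tgt):
    return jnp.mean((sim - tgt) ** 2)

sim = jnp.arange(6.0).reshape(3, 2)   # 3 nodes, 2 frequency bins
tgt = jnp.zeros((3, 2))

# vmap maps over the leading (node) axis, producing one loss per node
per_node = jax.vmap(_per_element_loss)(sim, tgt)
print(per_node.shape)      # (3,)
print(jnp.mean(per_node))  # scalar aggregate
```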

Execute with Real Data

Use the to_callable method to generate and execute the function:

import jax.numpy as jnp
import jax

# Create callable from the correlation function
correlation = corr_func.to_callable()

# Test with example data
x_sim = jnp.array([0.1, 0.5, 0.9, 0.7, 0.3])
x_emp = jnp.array([0.2, 0.6, 0.8, 0.6, 0.2])

corr = correlation(x_sim, x_emp)
print(f"Correlation: {corr:.4f}")
Correlation: 0.9428
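
As a sanity check (not part of the TVBO API), the same Pearson formula written directly in NumPy agrees with np.corrcoef on this data:

```python
import numpy as np

def pearson(x, y):
    # Same expression as the generated correlation code, NumPy backend
    xd, yd = x - x.mean(), y - y.mean()
    return np.sum(xd * yd) / np.sqrt(np.sum(xd**2) * np.sum(yd**2))

x = np.array([0.1, 0.5, 0.9, 0.7, 0.3])
y = np.array([0.2, 0.6, 0.8, 0.6, 0.2])
print(round(float(pearson(x, y)), 4))                      # 0.9428
print(np.isclose(pearson(x, y), np.corrcoef(x, y)[0, 1]))  # True
```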

Execute Loss Function with Multi-Node Data

import jax
# Create callable loss function
# Pass correlation function in namespace since it's referenced
loss_fn = loss_func.to_callable(
    inner_func_names=["correlation"],
    namespace={"correlation": correlation}
)

# Multi-node test data (84 regions, 100 frequency bins each)
n_nodes = 84
n_freqs = 100
key = jax.random.PRNGKey(42)

# Simulated and target PSDs
simulated_psd = jax.random.uniform(key, (n_nodes, n_freqs))
target_psd = simulated_psd + 0.1 * jax.random.normal(jax.random.PRNGKey(0), (n_nodes, n_freqs))

# Compute loss (scalar output!)
loss_value = loss_fn(simulated_psd, target_psd)
print(f"Loss value (mean over {n_nodes} nodes): {loss_value:.6f}")
print(f"Output shape: {loss_value.shape} (scalar)")
Loss value (mean over 84 nodes): 0.055940
Output shape: () (scalar)
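
The scalar result can be reproduced by hand: loop over nodes, compute 1 - corr per node, then average. A NumPy sketch with small synthetic data (not the TVBO code path):

```python
import numpy as np

rng = np.random.default_rng(42)
n_nodes, n_freqs = 4, 10
sim = rng.uniform(size=(n_nodes, n_freqs))
tgt = sim + 0.1 * rng.normal(size=(n_nodes, n_freqs))

def pearson(x, y):
    xd, yd = x - x.mean(), y - y.mean()
    return np.sum(xd * yd) / np.sqrt(np.sum(xd**2) * np.sum(yd**2))

# Per-node losses, then mean aggregation -> a single scalar
loss = np.mean([1.0 - pearson(sim[i], tgt[i]) for i in range(n_nodes)])
print(np.ndim(loss))  # 0 (scalar)
```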

Class Methods Reference

Function

Method                       Description
---------------------------  ------------------------------------------
from_file(path)              Load from a YAML file
from_string(yaml_str)        Load from a YAML string
render_code(format='jax')    Generate Python code
to_jax()                     Shorthand for render_code(format='jax')
to_numpy()                   Shorthand for render_code(format='numpy')
to_callable()                Generate and return an executable function
sympy_expression             Property: parsed SymPy expression
latex                        Property: LaTeX representation

LossFunction

Inherits all Function methods, plus:

Method                                          Description
----------------------------------------------  ------------------------------------
render_code(inner_func_names=[])                Generate loss code with aggregation
to_callable(inner_func_names=[], namespace={})  Generate an executable loss function
aggregation_type                                Property: 'mean', 'sum', etc.
aggregation_dimension                           Property: 'node', 'time', etc.

Aggregation Reference

For multi-region data, the per-node losses are aggregated into a single scalar:

\[L = \frac{1}{N} \sum_{i=0}^{N-1} \left(1 - \text{corr}(x_i, y_i)\right)\]

aggregate:
  over: node
  type: mean
aggregate.type  Mathematical form              Generated code
--------------  -----------------------------  ------------------------
mean            \(\frac{1}{N}\sum_i f(x_i)\)   jnp.mean(jax.vmap(f)(x))
sum             \(\sum_i f(x_i)\)              jnp.sum(jax.vmap(f)(x))
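
The two table rows map directly onto JAX one-liners; a toy check with f(x) = x² (illustration only, not generated code):

```python
import jax
import jax.numpy as jnp

f = lambda x: x ** 2
xs = jnp.array([1.0, 2.0, 3.0])

# mean aggregation: (1 + 4 + 9) / 3
print(jnp.mean(jax.vmap(f)(xs)))
# sum aggregation: 1 + 4 + 9
print(jnp.sum(jax.vmap(f)(xs)))   # 14.0
```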