# Loss Function Specification

```python
from tvbo import Function, LossFunction
```
## Overview
TVBO uses SymPy for symbolic mathematics, enabling:
- Mathematical notation in YAML specs
- LaTeX rendering for documentation
- Code generation to JAX/NumPy
- Direct execution via class methods
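The same pipeline can be sketched with plain SymPy (not the TVBO wrappers): parse a string, render it to LaTeX, and generate a NumPy-backed callable.

```python
import sympy as sp

# Plain-SymPy sketch of the pipeline above (not the TVBO wrappers):
# parse a string, render LaTeX, and generate an executable callable.
x, y = sp.symbols("x y")
expr = sp.sympify("(x - y)**2")          # mathematical notation as a string
latex_str = sp.latex(expr)               # LaTeX for documentation
f = sp.lambdify((x, y), expr, "numpy")   # NumPy-backed generated code

print(latex_str)
print(f(3.0, 1.0))  # 4.0
```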
## Loading Functions from YAML

Functions are defined in YAML and loaded using the extended `Function` and `LossFunction` classes.
### Example: Correlation Function

From `JR_tvboptim.qmd`:
```python
# YAML specification for correlation function
correlation_yaml = """
name: correlation
description: "Pearson correlation coefficient"
equation:
  rhs: "Sum((x[i] - mean(x)) * (y[i] - mean(y)), (i, 0, n-1)) / sqrt(Sum((x[i] - mean(x))**2, (i, 0, n-1)) * Sum((y[i] - mean(y))**2, (i, 0, n-1)))"
definition: "Pearson product-moment correlation coefficient"
arguments:
  - name: x
    description: "First array"
  - name: y
    description: "Second array"
"""
```
```python
# Load as Function class
corr_func = Function.from_string(correlation_yaml)

# Access SymPy expression
expr = corr_func.sympy_expression
expr
```

\(\displaystyle \frac{\sum_{i=0}^{n - 1} \left(- \operatorname{mean}{\left(x \right)} + {x}_{i}\right) \left(- \operatorname{mean}{\left(y \right)} + {y}_{i}\right)}{\sqrt{\left(\sum_{i=0}^{n - 1} \left(- \operatorname{mean}{\left(x \right)} + {x}_{i}\right)^{2}\right) \sum_{i=0}^{n - 1} \left(- \operatorname{mean}{\left(y \right)} + {y}_{i}\right)^{2}}}\)
### Generate Executable Code

```python
# Generate JAX code using the render_code method
jax_code = corr_func.render_code(format="jax")
print(jax_code)
```

```python
def correlation(x, y):
    """Pearson correlation coefficient"""
    return jnp.sum((-jnp.mean(x) + x)*(-jnp.mean(y) + y))/jnp.sqrt(jnp.sum((-jnp.mean(x) + x)**2)*jnp.sum((-jnp.mean(y) + y)**2))
```
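As a sanity check (a sketch, not TVBO output), the same formula can be written with NumPy in place of `jnp` and compared against NumPy's built-in `corrcoef`:

```python
import numpy as np

# NumPy analogue of the generated JAX function (jnp.* replaced by np.*),
# checked against NumPy's built-in Pearson correlation.
def correlation(x, y):
    return np.sum((x - np.mean(x)) * (y - np.mean(y))) / np.sqrt(
        np.sum((x - np.mean(x)) ** 2) * np.sum((y - np.mean(y)) ** 2)
    )

x = np.array([1.0, 2.0, 4.0, 3.0])
y = np.array([0.5, 1.9, 4.2, 2.8])
assert np.isclose(correlation(x, y), np.corrcoef(x, y)[0, 1])
```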
### Example: Loss Function with Aggregation

The `LossFunction` class extends `Function` with aggregation over dimensions:
```python
# YAML specification for loss function
loss_yaml = """
name: spectral_correlation
label: "Spectral Correlation Loss"
equation:
  rhs: "1 - correlation(simulated_psd, target_data)"
aggregate:
  over: node
  type: mean
description: >
  Pearson correlation between simulated and target power spectra.
  Loss computed per node, then averaged across all nodes.
"""
```
```python
# Load as LossFunction class
loss_func = LossFunction.from_string(loss_yaml)

# View SymPy expression
loss_expr = loss_func.sympy_expression
loss_expr
```

\(\displaystyle 1 - \operatorname{correlation}{\left(simulated_{psd},target_{data} \right)}\)
### Generate Loss Function Code with Aggregation

```python
# Generate loss function code with aggregation
# inner_func_names tells the generator which functions are user-defined
loss_code = loss_func.render_code(inner_func_names=["correlation"])
print(loss_code)
```

```python
def spectral_correlation(simulated_psd, target_data):
    """Pearson correlation between simulated and target power spectra.
    Loss computed per node, then averaged across all nodes.
    """
    def _per_element_loss(simulated_psd, target_data):
        return 1 - correlation(simulated_psd, target_data)

    # vmap over node dimension, then reduce with mean
    per_element_losses = jax.vmap(_per_element_loss)(simulated_psd, target_data)
    return jnp.mean(per_element_losses)
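The generated structure (map a per-element loss over the leading node axis, then reduce) can be sketched in plain NumPy, with an explicit loop standing in for `jax.vmap` and a hypothetical per-node loss:

```python
import numpy as np

# Explicit-loop sketch of the vmap-then-reduce pattern. per_element_loss is
# a hypothetical stand-in for the generated per-node loss.
def per_element_loss(sim_row, tgt_row):
    return np.mean((sim_row - tgt_row) ** 2)

def aggregated_loss(sim, tgt):
    per_node = np.array([per_element_loss(s, t) for s, t in zip(sim, tgt)])
    return np.mean(per_node)  # 'type: mean' reduction over nodes

sim = np.ones((4, 10))
print(aggregated_loss(sim, sim))  # 0.0 -- identical data gives zero loss
```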
## Execute with Real Data

Use the `to_callable` method to generate and execute the function:
```python
import jax.numpy as jnp
import jax

# Create callable from the correlation function
correlation = corr_func.to_callable()

# Test with example data
x_sim = jnp.array([0.1, 0.5, 0.9, 0.7, 0.3])
x_emp = jnp.array([0.2, 0.6, 0.8, 0.6, 0.2])
corr = correlation(x_sim, x_emp)
print(f"Correlation: {corr:.4f}")
```

```
Correlation: 0.9428
```
## Execute Loss Function with Multi-Node Data

```python
import jax

# Create callable loss function
# Pass correlation function in namespace since it's referenced
loss_fn = loss_func.to_callable(
    inner_func_names=["correlation"],
    namespace={"correlation": correlation}
)

# Multi-node test data (84 regions, 100 frequency bins each)
n_nodes = 84
n_freqs = 100
key = jax.random.PRNGKey(42)

# Simulated and target PSDs
simulated_psd = jax.random.uniform(key, (n_nodes, n_freqs))
target_psd = simulated_psd + 0.1 * jax.random.normal(jax.random.PRNGKey(0), (n_nodes, n_freqs))

# Compute loss (scalar output!)
loss_value = loss_fn(simulated_psd, target_psd)
print(f"Loss value (mean over {n_nodes} nodes): {loss_value:.6f}")
print(f"Output shape: {loss_value.shape} (scalar)")
```

```
Loss value (mean over 84 nodes): 0.055940
Output shape: () (scalar)
```
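The magnitude of the result is plausible: a NumPy sketch of the same loss (1 minus the per-node correlation, averaged over nodes) on similarly constructed data gives a comparably small value. The random streams differ from JAX's, so only the scale matches.

```python
import numpy as np

# NumPy sketch of the multi-node loss: 1 - correlation per node, then a
# mean over nodes. Data construction mirrors the JAX example above.
rng = np.random.default_rng(42)
sim = rng.uniform(size=(84, 100))
tgt = sim + 0.1 * rng.normal(size=(84, 100))

loss = np.mean([1 - np.corrcoef(s, t)[0, 1] for s, t in zip(sim, tgt)])
print(f"Loss: {loss:.4f}")  # small positive value, comparable to the JAX run
```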
## Class Methods Reference

### Function

| Method | Description |
|---|---|
| `from_file(path)` | Load from YAML file |
| `from_string(yaml_str)` | Load from YAML string |
| `render_code(format='jax')` | Generate Python code |
| `to_jax()` | Shorthand for `render_code(format='jax')` |
| `to_numpy()` | Shorthand for `render_code(format='numpy')` |
| `to_callable()` | Generate and return executable function |
| `sympy_expression` | Property: parsed SymPy expression |
| `latex` | Property: LaTeX representation |
### LossFunction

Inherits all `Function` methods, plus:

| Method | Description |
|---|---|
| `render_code(inner_func_names=[])` | Generate loss code with aggregation |
| `to_callable(inner_func_names=[], namespace={})` | Generate executable loss function |
| `aggregation_type` | Property: `'mean'`, `'sum'`, etc. |
| `aggregation_dimension` | Property: `'node'`, `'time'`, etc. |
## Aggregation Reference

For multiple regions, the loss aggregates per-node losses:

\[L = \frac{1}{N} \sum_{i=0}^{N-1} \left(1 - \text{corr}(x_i, y_i)\right)\]

```yaml
aggregate:
  over: node
  type: mean
```

| `aggregate.type` | Mathematical | Generated Code |
|---|---|---|
| `mean` | \(\frac{1}{N}\sum_i f(x_i)\) | `jnp.mean(jax.vmap(f)(x))` |
| `sum` | \(\sum_i f(x_i)\) | `jnp.sum(jax.vmap(f)(x))` |
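The two reductions differ exactly by the factor \(N\), so `mean` keeps the loss scale independent of the number of nodes. A one-line NumPy check of that relationship:

```python
import numpy as np

# mean and sum aggregation of the same per-element losses differ by N.
per_element = np.array([0.1, 0.2, 0.3, 0.4])
N = per_element.size
assert np.isclose(np.mean(per_element) * N, np.sum(per_element))
```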