function

knowledge.function

Function Classes

Extended Function and LossFunction classes with code generation methods. Both inherit from the LinkML datamodel and add rendering and execution capabilities.

Usage

From YAML file:

from tvbo import Function, LossFunction

func = Function.from_file("correlation.yaml")
code = func.render_code(format='jax')
callable_fn = func.to_callable()

From YAML string:

func = Function.from_string(yaml_string)

From datamodel object:

from tvbo.datamodel import tvbo_datamodel
dm_func = tvbo_datamodel.Function(name='sigmoid', ...)
func = Function.from_datamodel(dm_func)
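
The three entry points above follow the usual alternate-constructor pattern. The sketch below is a minimal stand-in, not tvbo's implementation: `SimpleFunction` and its two fields are hypothetical, and JSON from the standard library replaces YAML so that the block runs without extra dependencies.

```python
import json
import tempfile
from dataclasses import dataclass
from pathlib import Path


@dataclass
class SimpleFunction:
    """Hypothetical stand-in for an extended Function class."""

    name: str
    equation: str

    @classmethod
    def from_string(cls, text):
        # Parse the serialized spec and forward its fields to the constructor.
        return cls(**json.loads(text))

    @classmethod
    def from_file(cls, path):
        # Reading a file reduces to the string case.
        return cls.from_string(Path(path).read_text())

    @classmethod
    def from_datamodel(cls, obj):
        # Copy the fields of an existing plain object into the extended class.
        return cls(name=obj.name, equation=obj.equation)


spec = '{"name": "sigmoid", "equation": "1 / (1 + exp(-x))"}'
from_text = SimpleFunction.from_string(spec)

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "sigmoid.json"
    path.write_text(spec)
    from_disk = SimpleFunction.from_file(path)

copied = SimpleFunction.from_datamodel(from_text)
```

All three routes end in the same constructor call, so the resulting objects compare equal.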

Classes

Name          Description
Function      Extended Function class with code generation and execution methods.
LossFunction  Extended LossFunction class with code generation and execution methods.

Function

knowledge.function.Function(name='Function', **kwargs)

Extended Function class with code generation and execution methods.

Inherits all schema fields from tvbo_datamodel.Function and adds:

- Factory constructors: from_file, from_string, from_datamodel
- Code generation: render_code, to_jax, to_numpy
- Execution: to_callable

Attributes

Name              Description
latex             Return LaTeX representation of the function equation.
sympy_expression  Return the parsed SymPy expression for this function’s equation.

Methods

Name            Description
from_datamodel  Create Function from a tvbo_datamodel.Function instance.
from_file       Load Function from a YAML file.
from_string     Load Function from a YAML string.
render_code     Generate Python code for this function.
to_callable     Generate and execute function code, returning the callable.
to_jax          Generate JAX code for this function.
to_numpy        Generate NumPy code for this function.
to_python       Generate pure Python code for this function.

from_datamodel
knowledge.function.Function.from_datamodel(func)

Create Function from a tvbo_datamodel.Function instance.

from_file
knowledge.function.Function.from_file(path)

Load Function from a YAML file.

from_string
knowledge.function.Function.from_string(yaml_str)

Load Function from a YAML string.

render_code
knowledge.function.Function.render_code(
    format='jax',
    user_functions=None,
    render_func=None,
)

Generate Python code for this function.

Parameters

format : str
    Output format: 'jax', 'numpy', 'python'
user_functions : dict, optional
    Custom function name mappings for the printer. Example: {'sigmoid': 'sigmoid'} to preserve the function name.
render_func : callable, optional
    Custom render function for model context.

Returns

str
    Python code string defining the function.

Examples

>>> func = Function.from_string(yaml_str)
>>> print(func.render_code())
def sigmoid(x):
    return 1/(1 + jnp.exp(-x))
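
To make the roles of `format` and `user_functions` concrete, here is a small sketch of rendering an expression string into a function definition. It is an illustration under assumptions, not the library's renderer: `render_function_code`, `MODULE_PREFIX`, and `KNOWN_FUNCTIONS` are hypothetical names, and a regular expression stands in for whatever printer tvbo actually uses.

```python
import re

# Hypothetical prefixes; the real mapping used by render_code is not shown here.
MODULE_PREFIX = {"jax": "jnp", "numpy": "np", "python": "math"}
KNOWN_FUNCTIONS = ("exp", "log", "sqrt", "tanh")


def render_function_code(name, args, expression, format="jax", user_functions=None):
    """Return source text for `def name(args): return expression`."""
    prefix = MODULE_PREFIX[format]
    # Names listed in user_functions are emitted under the mapped name, unprefixed.
    rename = dict(user_functions or {})

    def qualify(match):
        func = match.group(1)
        if func in rename:
            return rename[func] + "("
        if func in KNOWN_FUNCTIONS:
            return f"{prefix}.{func}("
        return match.group(0)

    # Match an identifier followed by "(" that is not already module-qualified.
    body = re.sub(r"(?<![\w.])([A-Za-z_]\w*)\(", qualify, expression)
    return f"def {name}({', '.join(args)}):\n    return {body}\n"


jax_code = render_function_code("sigmoid", ["x"], "1/(1 + exp(-x))")
numpy_code = render_function_code("sigmoid", ["x"], "1/(1 + exp(-x))", format="numpy")
kept = render_function_code("act", ["x"], "tanh(x)", user_functions={"tanh": "tanh"})
```

In this sketch the only difference between formats is the module prefix, and an entry in `user_functions` suppresses the prefix so that the name is emitted as given.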

to_callable
knowledge.function.Function.to_callable(
    format='jax',
    user_functions=None,
    namespace=None,
)

Generate and execute function code, returning the callable.

Parameters

format : str
    Output format: 'jax', 'numpy'
user_functions : dict, optional
    Custom function name mappings.
namespace : dict, optional
    Namespace for exec(). If None, creates one with jnp/np imports.

Returns

callable
    The generated function as a callable.

Examples

>>> func = Function.from_string(sigmoid_yaml)
>>> sigmoid = func.to_callable()
>>> sigmoid(0.0)
0.5
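
The "generate and execute" step can be pictured as running the rendered source with exec() inside a namespace and then looking the function up by name. The following is a minimal sketch of that idea using only the standard library; `code_to_callable` is a hypothetical helper, and `math` stands in for the jnp/np imports mentioned above.

```python
import math


def code_to_callable(code, name, namespace=None):
    """Execute generated source and return the function it defines."""
    if namespace is None:
        # Default namespace: only what the generated code is expected to use.
        namespace = {"math": math}
    exec(code, namespace)  # defines `name` inside namespace
    return namespace[name]


source = "def sigmoid(x):\n    return 1/(1 + math.exp(-x))\n"
sigmoid = code_to_callable(source, "sigmoid")
print(sigmoid(0.0))  # 0.5
```

Because exec() runs arbitrary code, this pattern is only appropriate for source text from a trusted origin.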

to_jax
knowledge.function.Function.to_jax(**kwargs)

Generate JAX code for this function.

to_numpy
knowledge.function.Function.to_numpy(**kwargs)

Generate NumPy code for this function.

to_python
knowledge.function.Function.to_python(**kwargs)

Generate pure Python code for this function.

LossFunction

knowledge.function.LossFunction(name='LossFunction', **kwargs)

Extended LossFunction class with code generation and execution methods.

Inherits all schema fields from tvbo_datamodel.LossFunction and adds:

- Factory constructors: from_file, from_string, from_datamodel
- Code generation: render_code, to_jax, to_numpy
- Execution: to_callable

LossFunction extends Function with aggregation specification for per-element losses (e.g., mean over nodes).

Attributes

Name                   Description
aggregation_dimension  Return the aggregation dimension as a string (e.g., 'node').
aggregation_type       Return the aggregation type as a string (e.g., 'mean', 'sum').
latex                  Return LaTeX representation of the loss function equation.
sympy_expression       Return the parsed SymPy expression for this function’s equation.

Methods

Name            Description
from_datamodel  Create LossFunction from a tvbo_datamodel.LossFunction instance.
from_file       Load LossFunction from a YAML file.
from_string     Load LossFunction from a YAML string.
render_code     Generate Python code for this loss function with aggregation.
to_callable     Generate and execute loss function code, returning the callable.
to_jax          Generate JAX code for this loss function.
to_numpy        Generate NumPy code for this loss function.

from_datamodel
knowledge.function.LossFunction.from_datamodel(func)

Create LossFunction from a tvbo_datamodel.LossFunction instance.

from_file
knowledge.function.LossFunction.from_file(path)

Load LossFunction from a YAML file.

from_string
knowledge.function.LossFunction.from_string(yaml_str)

Load LossFunction from a YAML string.

render_code
knowledge.function.LossFunction.render_code(
    format='jax',
    user_functions=None,
    inner_func_names=None,
)

Generate Python code for this loss function with aggregation.

Parameters

format : str
    Output format: 'jax', 'numpy'
user_functions : dict, optional
    Custom function name mappings.
inner_func_names : list, optional
    Names of inner functions that should be recognized. Example: ['correlation'] for "1 - correlation(x, y)".

Returns

str
    Python code string defining the loss function.

Examples

>>> loss = LossFunction.from_string(loss_yaml)
>>> print(loss.render_code(inner_func_names=['correlation']))
def spectral_loss(sim, target):
    def _per_element_loss(sim, target):
        return 1 - correlation(sim, target)
    per_element_losses = jax.vmap(_per_element_loss)(sim, target)
    return jnp.mean(per_element_losses)
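
The generated code has two stages: a per-element loss mapped over the leading axis, followed by a reduction chosen by the aggregation type. The sketch below reproduces that structure in plain Python so it runs without JAX; a list comprehension plays the role of jax.vmap, and `aggregate_loss`, `REDUCERS`, and this `correlation` are hypothetical stand-ins rather than tvbo functions.

```python
from math import fsum, sqrt
from statistics import fmean

# Assumed mapping from aggregation type to a reducer.
REDUCERS = {"mean": fmean, "sum": fsum}


def correlation(x, y):
    """Pearson correlation of two equal-length sequences."""
    mx, my = fmean(x), fmean(y)
    cov = fsum((a - mx) * (b - my) for a, b in zip(x, y))
    var_x = fsum((a - mx) ** 2 for a in x)
    var_y = fsum((b - my) ** 2 for b in y)
    return cov / sqrt(var_x * var_y)


def aggregate_loss(per_element_loss, sim, target, aggregation_type="mean"):
    # Apply the loss to each (sim, target) pair, e.g. one pair per node,
    # then reduce the per-element values to a single scalar.
    losses = [per_element_loss(s, t) for s, t in zip(sim, target)]
    return REDUCERS[aggregation_type](losses)


# Two nodes: the first is perfectly correlated, the second anti-correlated.
sim = [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]
target = [[2.0, 4.0, 6.0], [3.0, 2.0, 1.0]]


def per_element(s, t):
    return 1 - correlation(s, t)


mean_loss = aggregate_loss(per_element, sim, target)
sum_loss = aggregate_loss(per_element, sim, target, aggregation_type="sum")
```

Here the per-node losses are 0 and 2, so the mean aggregation gives 1 and the sum gives 2.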

to_callable
knowledge.function.LossFunction.to_callable(
    format='jax',
    user_functions=None,
    inner_func_names=None,
    namespace=None,
)

Generate and execute loss function code, returning the callable.

Parameters

format : str
    Output format: 'jax', 'numpy'
user_functions : dict, optional
    Custom function name mappings.
inner_func_names : list, optional
    Names of inner functions to recognize.
namespace : dict, optional
    Namespace for exec(). If None, creates one with jnp/np/jax imports.

Returns

callable
    The generated loss function as a callable.
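
When generated loss code calls an inner function by name, that name has to exist in the namespace passed to exec(). The sketch below illustrates this with the standard library only; `mean_abs_error`, `scaled_loss`, and the source text are hypothetical, and whether tvbo resolves inner functions in exactly this way is an assumption.

```python
def mean_abs_error(sim, target):
    """Hypothetical inner function that the generated loss refers to by name."""
    return sum(abs(s - t) for s, t in zip(sim, target)) / len(sim)


# Source text as a code generator might emit it. mean_abs_error is not defined
# in this text, so it has to be supplied through the exec namespace.
loss_source = (
    "def scaled_loss(sim, target):\n"
    "    return 2 * mean_abs_error(sim, target)\n"
)

namespace = {"mean_abs_error": mean_abs_error}
exec(loss_source, namespace)
scaled_loss = namespace["scaled_loss"]

value = scaled_loss([1.0, 2.0], [2.0, 4.0])
```

If the inner function is left out of the namespace, the exec() call still succeeds, and the failure only appears as a NameError when the loss is first called.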

to_jax
knowledge.function.LossFunction.to_jax(**kwargs)

Generate JAX code for this loss function.

to_numpy
knowledge.function.LossFunction.to_numpy(**kwargs)

Generate NumPy code for this loss function.