utils.utils

Functions

Name                     Description
broadcast_1d_array       Broadcast a 1D array of shape (N,) to shape (N, *additional_dims).
format_pytree_as_string  Recursively formats a JAX pytree structure as a string with Unicode box-drawing characters.
pretty_print_pytree      Prints a pretty formatted representation of a JAX pytree structure.
safe_reshape             Reshapes an array, raising an error if new_shape is too small and filling with fill_value if it is larger.

broadcast_1d_array

utils.utils.broadcast_1d_array(arr_1d, additional_dims=())

Broadcast a 1D array of shape (N,) to shape (N, *additional_dims) with a single reshape operation.

Parameters:

arr_1d : numpy.ndarray or jax.numpy.ndarray
    1D input array of shape (N,).
additional_dims : tuple
    Additional dimensions to broadcast to. Can be the empty tuple () for no additional dimensions.

Returns:

numpy.ndarray or jax.numpy.ndarray
    Broadcasted array of shape (N, *additional_dims).
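
The shape contract above can be illustrated with a minimal sketch. The helper `broadcast_1d_sketch` is hypothetical and is not the library's implementation; it assumes that each of the N input values is repeated across all of the additional dimensions, and it uses NumPy only so that it runs without the package installed.

```python
import numpy as np


def broadcast_1d_sketch(arr_1d, additional_dims=()):
    """Hypothetical stand-in for broadcast_1d_array (assumed semantics)."""
    n = arr_1d.shape[0]
    # Reshape once to (N, 1, ..., 1), one singleton axis per additional dim.
    expanded = arr_1d.reshape((n,) + (1,) * len(additional_dims))
    # Expand the singleton axes to the requested sizes without copying data.
    return np.broadcast_to(expanded, (n,) + tuple(additional_dims))


arr = np.array([1.0, 2.0, 3.0])
out = broadcast_1d_sketch(arr, (2, 4))   # shape (3, 2, 4)
unchanged = broadcast_1d_sketch(arr)     # shape (3,), no extra dims
```

Under this assumption, `out[i]` is a (2, 4) block filled with `arr[i]`, which is the usual layout for multiplying per-sample weights against batched data.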

format_pytree_as_string

utils.utils.format_pytree_as_string(
    pytree,
    name='root',
    prefix='',
    is_last=False,
    show_numerical_only=False,
    is_root=True,
    hide_none=False,
    show_array_values=False,
)

Recursively formats a JAX pytree structure as a string with Unicode box-drawing characters.

Args:
    pytree: The pytree to format.
    name: The name of the current node.
    prefix: Current line prefix.
    is_last: Whether the current node is the last child of its parent.
    show_numerical_only: If True, only show arrays and numerical types (float, int, etc.).
    is_root: Whether this node is the root of the tree.
    hide_none: If True, fields with None values will be hidden.

Returns:
    str: The formatted string representation of the pytree.
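
To show how `prefix`, `is_last`, and `is_root` typically interact in a recursive box-drawing formatter, here is a simplified sketch. The function `format_tree_sketch` is hypothetical: it handles only dicts, lists, tuples, and leaves, the leaf label format is an assumption, and it omits the `show_numerical_only`, `hide_none`, and `show_array_values` options. It is not the library's implementation.

```python
import numpy as np


def format_tree_sketch(node, name="root", prefix="", is_last=True, is_root=True):
    """Hypothetical, simplified tree formatter (not format_pytree_as_string)."""
    # The root line has no connector; children get a branch or corner glyph.
    connector = "" if is_root else ("└── " if is_last else "├── ")

    if isinstance(node, dict):
        children = list(node.items())
    elif isinstance(node, (list, tuple)):
        children = [(f"[{i}]", child) for i, child in enumerate(node)]
    else:
        children = None

    if children is None:
        # Leaf: summarize arrays by shape and dtype, show other values via repr.
        if hasattr(node, "shape"):
            label = f"{type(node).__name__}(shape={tuple(node.shape)}, dtype={node.dtype})"
        else:
            label = repr(node)
        return f"{prefix}{connector}{name}: {label}"

    lines = [f"{prefix}{connector}{name}"]
    # Children of a last node are indented with spaces; otherwise the
    # vertical bar continues so later siblings stay visually connected.
    child_prefix = prefix if is_root else prefix + ("    " if is_last else "│   ")
    for i, (child_name, child) in enumerate(children):
        last_child = i == len(children) - 1
        lines.append(format_tree_sketch(child, child_name, child_prefix, last_child, False))
    return "\n".join(lines)


params = {"layer": {"w": np.zeros((2, 3)), "b": np.zeros(3)}, "lr": 0.1}
text = format_tree_sketch(params)
print(text)
```

The caller normally passes only the tree and a root name; the remaining arguments exist so that each recursive call can tell its children how to draw their connectors.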

pretty_print_pytree

utils.utils.pretty_print_pytree(
    pytree,
    name='root',
    prefix='',
    show_numerical_only=False,
    hide_none=False,
)

Prints a pretty formatted representation of a JAX pytree structure.

Args:
    pytree: The pytree to print.
    name: The name of the current node.
    prefix: Current line prefix.
    show_numerical_only: If True, only show arrays and numerical types (float, int, etc.).
    hide_none: If True, fields with None values will be hidden.

safe_reshape

utils.utils.safe_reshape(arr, new_shape, fill_value=jnp.nan)

A safe reshaping function with the following properties:

- If new_shape has fewer elements than arr, raises an error.
- If new_shape has the same number of elements as arr, performs a standard reshape.
- If new_shape requires more elements than arr, fills the extra space with fill_value.

Args:
    arr: JAX array to reshape.
    new_shape: Tuple of integers specifying the new shape.
    fill_value: Value to use for filling extra elements (default: jnp.nan).

Returns:
    Reshaped JAX array.
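
The three documented cases can be sketched as follows. `safe_reshape_sketch` is a hypothetical stand-in written with NumPy so that it runs without JAX; the exception type and the placement of the fill values (appended after the existing elements in flattened order) are assumptions, since the reference does not specify them.

```python
import math

import numpy as np


def safe_reshape_sketch(arr, new_shape, fill_value=np.nan):
    """Hypothetical stand-in for safe_reshape; not the library's code."""
    flat = np.ravel(arr)
    needed = math.prod(new_shape)
    if needed < flat.size:
        # Case 1: the target is too small, so data would be lost.
        raise ValueError(
            f"cannot reshape {flat.size} elements into shape {new_shape}"
        )
    if needed > flat.size:
        # Case 3 (assumed layout): pad the flattened data at the end.
        padding = np.full(needed - flat.size, fill_value)
        flat = np.concatenate([flat, padding])
    # Case 2, and case 3 after padding: an ordinary reshape.
    return flat.reshape(new_shape)


arr = np.arange(4.0)                          # [0., 1., 2., 3.]
same = safe_reshape_sketch(arr, (2, 2))       # equal element count
padded = safe_reshape_sketch(arr, (2, 3))     # two extra slots hold NaN
```

Note that a NaN fill value only makes sense for floating-point arrays; in this sketch an integer input would be promoted to float by the concatenation.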