utils.utils
Functions
| Name | Description |
|---|---|
| broadcast_1d_array | Broadcast a 1D array of shape (N,) to shape (N, *additional_dims) |
| format_pytree_as_string | Recursively formats a JAX pytree structure as a string with Unicode box-drawing characters. |
| pretty_print_pytree | Prints a pretty-formatted representation of a JAX pytree structure. |
| safe_reshape | A safe reshaping function that raises an error when the new shape is smaller and pads with a fill value when it is larger. |
broadcast_1d_array
`utils.utils.broadcast_1d_array(arr_1d, additional_dims=())`

Broadcast a 1D array of shape (N,) to shape (N, *additional_dims) with a single reshape operation.
Parameters:

| Name | Type | Description |
|---|---|---|
| arr_1d | numpy.ndarray or jax.numpy.ndarray | 1D input array of shape (N,). |
| additional_dims | tuple | Additional dimensions to broadcast to. Can be the empty tuple () for no additional dimensions. |
Returns:

| Type | Description |
|---|---|
| numpy.ndarray or jax.numpy.ndarray | Broadcast array of shape (N, *additional_dims). |
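The docstring does not show the implementation, so the following is only a minimal sketch of one way to get from (N,) to (N, *additional_dims) with a single reshape followed by a broadcast. It uses NumPy so it runs without JAX (`jax.numpy` offers the same `reshape` and `broadcast_to` calls), and `broadcast_1d_array_sketch` is a hypothetical name, not part of `utils.utils`.

```python
import numpy as np


def broadcast_1d_array_sketch(arr_1d, additional_dims=()):
    """Hypothetical re-creation: shape (N,) -> shape (N, *additional_dims)."""
    arr_1d = np.asarray(arr_1d)
    if arr_1d.ndim != 1:
        raise ValueError(f"expected a 1D array, got shape {arr_1d.shape}")
    n = arr_1d.shape[0]
    # The single reshape: append one size-1 axis per extra dimension,
    # e.g. (N,) -> (N, 1, 1) when additional_dims has two entries.
    expanded = arr_1d.reshape((n,) + (1,) * len(additional_dims))
    # broadcast_to stretches the size-1 axes to the requested sizes;
    # in NumPy this returns a view instead of copying the data.
    return np.broadcast_to(expanded, (n, *additional_dims))


out = broadcast_1d_array_sketch(np.array([1, 2, 3]), (2, 4))
print(out.shape)  # (3, 2, 4); every out[i] is filled with arr_1d[i]
```

In NumPy the view returned by `broadcast_to` is read-only, so call `.copy()` before writing into it; JAX arrays are immutable anyway.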
format_pytree_as_string
`utils.utils.format_pytree_as_string(pytree, name='root', prefix='', is_last=False, show_numerical_only=False, is_root=True, hide_none=False, show_array_values=False)`

Recursively formats a JAX pytree structure as a string with Unicode box-drawing characters.
Args:

| Name | Description |
|---|---|
| pytree | The pytree to format. |
| name | The name of the current node. |
| prefix | The current line prefix. |
| is_last | Whether the current node is the last child of its parent. |
| show_numerical_only | If True, only show arrays and numerical types (float, int, etc.). |
| is_root | Whether this node is the root of the tree. |
| hide_none | If True, fields with None values are hidden. |
Returns:

| Type | Description |
|---|---|
| str | The formatted string representation of the pytree. |
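To illustrate the recursive, box-drawing layout described above, here is a hypothetical and simplified formatter, `format_tree_sketch`. It assumes dicts, lists and tuples are the only internal nodes and treats everything else as a leaf; it mirrors the `show_numerical_only` and `hide_none` filters, but the exact output of `format_pytree_as_string` (how it labels nodes, renders arrays, or handles other JAX pytree node types) may differ.

```python
import numpy as np

# Assumption: only these container types are expanded into child nodes.
CONTAINERS = (dict, list, tuple)


def format_tree_sketch(node, name="root", prefix="", is_last=True,
                       is_root=True, show_numerical_only=False,
                       hide_none=False):
    # The root is printed bare; every other node gets a branch connector.
    if is_root:
        line, child_prefix = name, ""
    else:
        line = prefix + ("└── " if is_last else "├── ") + name
        # Below a last child there is no sibling left, so no vertical bar.
        child_prefix = prefix + ("    " if is_last else "│   ")

    if isinstance(node, dict):
        children = list(node.items())
    elif isinstance(node, (list, tuple)):
        children = [(f"[{i}]", v) for i, v in enumerate(node)]
    elif isinstance(node, np.ndarray):
        # Hypothetical leaf style: summarise arrays by shape and dtype.
        return f"{line}: array{node.shape} {node.dtype}"
    else:
        return f"{line}: {node!r}"

    def visible(value):
        if value is None and hide_none:
            return False
        if show_numerical_only and not isinstance(value, CONTAINERS):
            return isinstance(value, (int, float, complex, np.number, np.ndarray))
        return True

    # Filter first so that the last *visible* child gets the └── connector.
    kept = [(k, v) for k, v in children if visible(v)]
    lines = [line]
    for i, (key, value) in enumerate(kept):
        lines.append(format_tree_sketch(
            value, name=str(key), prefix=child_prefix,
            is_last=(i == len(kept) - 1), is_root=False,
            show_numerical_only=show_numerical_only, hide_none=hide_none))
    return "\n".join(lines)


tree = {"params": {"w": np.zeros((2, 3)), "b": 0.5}, "name": "mlp", "extra": None}
print(format_tree_sketch(tree))
# root
# ├── params
# │   ├── w: array(2, 3) float64
# │   └── b: 0.5
# ├── name: 'mlp'
# └── extra: None
```

Filtering children before drawing connectors matters: if `hide_none` removed the last child after the connectors were chosen, the tree would end on a `├──` branch with nothing below it.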
pretty_print_pytree
`utils.utils.pretty_print_pytree(pytree, name='root', prefix='', show_numerical_only=False, hide_none=False)`

Prints a pretty-formatted representation of a JAX pytree structure.
Args:

| Name | Description |
|---|---|
| pytree | The pytree to print. |
| name | The name of the current node. |
| prefix | The current line prefix. |
| show_numerical_only | If True, only show arrays and numerical types (float, int, etc.). |
| hide_none | If True, fields with None values are hidden. |
safe_reshape
`utils.utils.safe_reshape(arr, new_shape, fill_value=jnp.nan)`

A safe reshaping function with the following properties:

- If `new_shape` has fewer elements than `arr`, it raises an error.
- If `new_shape` has the same number of elements as `arr`, it performs a standard reshape.
- If `new_shape` requires more elements than `arr`, it fills the extra space with `fill_value`.
Args:

| Name | Description |
|---|---|
| arr | JAX array to reshape. |
| new_shape | Tuple of integers specifying the new shape. |
| fill_value | Value used to fill extra elements (default: `jnp.nan`). |
Returns:

| Type | Description |
|---|---|
| JAX array | The reshaped array. |
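The three cases above can be sketched as follows. This is a minimal NumPy version under stated assumptions, not the library's code: `safe_reshape_sketch` is a hypothetical name, elements are assumed to be laid out in row-major (C) order with the padding appended at the end, and `ValueError` is an assumed error type (the docstring only says "raises an error"). `np.nan` stands in for the default `jnp.nan`.

```python
import math

import numpy as np


def safe_reshape_sketch(arr, new_shape, fill_value=np.nan):
    """Hypothetical re-creation of the behaviour described for safe_reshape."""
    arr = np.asarray(arr)
    new_size = math.prod(new_shape)
    if new_size < arr.size:
        # Shrinking would silently drop data, so refuse (assumed error type).
        raise ValueError(f"cannot fit {arr.size} elements into shape {tuple(new_shape)}")
    if new_size == arr.size:
        # Same element count: an ordinary reshape.
        return arr.reshape(new_shape)
    # Growing: append fill_value after the flattened data, then reshape.
    # The dtype is promoted if needed, e.g. an int array padded with NaN becomes float.
    pad = np.full(new_size - arr.size, fill_value,
                  dtype=np.result_type(arr, fill_value))
    return np.concatenate([arr.ravel(), pad]).reshape(new_shape)


print(safe_reshape_sketch(np.arange(4.0), (2, 3)))
# [[ 0.  1.  2.]
#  [ 3. nan nan]]
```

The padding is built with `concatenate` rather than by assigning into a preallocated buffer, so the same steps carry over to `jax.numpy`, where arrays cannot be modified in place.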