utils.utils


Functions

Name                 Description
broadcast_1d_array   Broadcast a 1D array of shape (N,) to shape (N, *additional_dims).
safe_reshape         Reshape an array, raising an error if the new shape holds fewer elements and padding with fill_value if it holds more.

broadcast_1d_array

utils.utils.broadcast_1d_array(arr_1d, additional_dims=())

Broadcast a 1D array of shape (N,) to shape (N, *additional_dims) with a single reshape operation.

Parameters:

arr_1d : numpy.ndarray or jax.numpy.ndarray
    1D input array of shape (N,).
additional_dims : tuple
    Additional dimensions to broadcast to. Can be an empty tuple () for no additional dimensions.

Returns:

numpy.ndarray or jax.numpy.ndarray
    Broadcasted array of shape (N, *additional_dims).
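The page does not show the implementation. The sketch below is a hypothetical re-implementation for illustration only, not the utils.utils source. It assumes the function reshapes (N,) to (N, 1, ..., 1) in one reshape and then broadcasts the singleton axes to the requested sizes. The explicit check that rejects non-1D input is also an assumption. It uses NumPy to stay self-contained; jax.numpy provides the same reshape and broadcast_to calls.

```python
import numpy as np


def broadcast_1d_array(arr_1d, additional_dims=()):
    """Hypothetical sketch: broadcast shape (N,) to (N, *additional_dims)."""
    arr_1d = np.asarray(arr_1d)
    if arr_1d.ndim != 1:
        # Assumption: non-1D input is rejected explicitly
        raise ValueError(f"expected a 1D array, got shape {arr_1d.shape}")
    additional_dims = tuple(additional_dims)
    # Single reshape: (N,) -> (N, 1, ..., 1), one singleton axis per extra dim
    expanded = arr_1d.reshape(arr_1d.shape + (1,) * len(additional_dims))
    # Broadcasting stretches each singleton axis to the requested size
    return np.broadcast_to(expanded, arr_1d.shape + additional_dims)


x = np.array([1.0, 2.0, 3.0])
y = broadcast_1d_array(x, (2, 4))
# y.shape == (3, 2, 4); every entry of y[i] equals x[i]
```

In this NumPy version, np.broadcast_to returns a read-only view rather than a copy, so take a copy (for example with y.copy()) before modifying the result.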

safe_reshape

utils.utils.safe_reshape(arr, new_shape, fill_value=jnp.nan)

A safe reshaping function with the following properties:

- If new_shape has fewer elements than arr, it raises an error.
- If new_shape has the same number of elements as arr, it performs a standard reshape.
- If new_shape requires more elements than arr, it fills the extra positions with fill_value.

Args:
    arr: JAX array to reshape.
    new_shape: Tuple of integers specifying the new shape.
    fill_value: Value used to fill the extra elements (default: jnp.nan).

Returns:
    Reshaped JAX array.
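The following minimal sketch makes the three cases concrete. It is a hypothetical stand-in written with NumPy so that it runs on its own. It is not the library's implementation. Two details are assumptions because the docstring leaves them unspecified: the original elements keep their row-major (C) order with the padding appended at the end, and the "too few elements" case raises ValueError. The same steps carry over to jax.numpy (jnp.full, jnp.concatenate, reshape).

```python
import math

import numpy as np


def safe_reshape(arr, new_shape, fill_value=np.nan):
    """Hypothetical sketch of the documented safe_reshape behavior (NumPy version)."""
    arr = np.asarray(arr)
    new_shape = tuple(new_shape)
    n_old = arr.size
    n_new = math.prod(new_shape)
    if n_new < n_old:
        # Case 1: target too small -> error (ValueError is an assumption)
        raise ValueError(
            f"cannot fit {n_old} elements into shape {new_shape} ({n_new} elements)"
        )
    if n_new == n_old:
        # Case 2: same element count -> plain reshape
        return arr.reshape(new_shape)
    # Case 3: target too large -> flatten, append padding, then reshape.
    # Pick a dtype that holds both the data and fill_value (e.g. int + NaN -> float).
    dtype = np.result_type(arr, fill_value)
    padding = np.full(n_new - n_old, fill_value, dtype=dtype)
    return np.concatenate([arr.ravel().astype(dtype), padding]).reshape(new_shape)


z = safe_reshape(np.arange(6), (2, 4))
# z == [[0., 1., 2., 3.],
#       [4., 5., nan, nan]]
```

With the default fill_value of NaN, this sketch promotes integer input to a floating dtype, because NaN has no integer representation. An integer fill_value such as 0 keeps the integer dtype.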