utils.utils
Functions
Name | Description |
---|---|
broadcast_1d_array | Broadcast a 1D array of shape (N,) to shape (N, *additional_dims) |
safe_reshape | Reshape an array, raising an error if the target shape has too few elements and padding with fill_value if it has too many |
broadcast_1d_array
utils.utils.broadcast_1d_array(arr_1d, additional_dims=())
Broadcast a 1D array of shape (N,) to shape (N, *additional_dims) with a single reshape operation.
Parameters:

Name | Type | Description |
---|---|---|
arr_1d | numpy.ndarray or jax.numpy.ndarray | 1D input array of shape (N,) |
additional_dims | tuple | Additional dimensions to broadcast to. Can be the empty tuple () for no additional dimensions. |
Returns:

Type | Description |
---|---|
numpy.ndarray or jax.numpy.ndarray | Broadcast array of shape (N, *additional_dims) |
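The following is a minimal sketch of one way to get this behaviour, not the module's actual implementation. It assumes NumPy input (jax.numpy exposes reshape and broadcast_to under the same names). A single reshape appends one singleton axis per extra dimension, and broadcast_to then stretches those axes. The 1D input check is a hypothetical addition.

```python
import numpy as np


def broadcast_1d_array(arr_1d, additional_dims=()):
    """Sketch: broadcast an (N,) array to (N, *additional_dims) using NumPy."""
    arr_1d = np.asarray(arr_1d)
    # Hypothetical input check, not taken from the documented function
    if arr_1d.ndim != 1:
        raise ValueError(f"expected a 1D array, got shape {arr_1d.shape}")
    additional_dims = tuple(additional_dims)
    # One reshape: (N,) -> (N, 1, ..., 1), one singleton axis per extra dim
    expanded = arr_1d.reshape(arr_1d.shape + (1,) * len(additional_dims))
    # Stretch the singleton axes without copying; the result is a read-only view
    return np.broadcast_to(expanded, arr_1d.shape + additional_dims)


scales = np.array([1.0, 2.0, 3.0])
out = broadcast_1d_array(scales, (2, 4))
print(out.shape)  # (3, 2, 4)
print(out[1])     # a 2x4 block in which every entry is 2.0
```

With additional_dims=() the reshape and the broadcast are both no-ops and the input shape (N,) is returned. Because np.broadcast_to returns a read-only view, call .copy() on the result before writing to it.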
safe_reshape
utils.utils.safe_reshape(arr, new_shape, fill_value=jnp.nan)
A safe reshaping function with the following properties:

- If new_shape has fewer elements than arr, it raises an error.
- If new_shape has the same number of elements as arr, it performs a standard reshape.
- If new_shape requires more elements than arr, it fills the extra space with fill_value.
Args:

Name | Description |
---|---|
arr | JAX array to reshape |
new_shape | Tuple of integers specifying the new shape |
fill_value | Value used to fill extra elements (default: jnp.nan) |
Returns:

Reshaped JAX array of shape new_shape.
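A minimal sketch of the three documented cases follows. It uses NumPy in place of jax.numpy (which provides reshape, full, concatenate and result_type under the same names). It also assumes that the fill values are appended after the existing elements in row-major order. The source does not say where the padding goes, so treat that as an assumption.

```python
import math

import numpy as np


def safe_reshape(arr, new_shape, fill_value=np.nan):
    """Sketch of the documented behaviour, using NumPy instead of jax.numpy."""
    arr = np.asarray(arr)
    new_shape = tuple(new_shape)
    needed = math.prod(new_shape)
    flat = arr.reshape(-1)
    # Case 1: target too small -> error
    if needed < flat.size:
        raise ValueError(
            f"cannot reshape {arr.shape} ({flat.size} elements) "
            f"into {new_shape} ({needed} elements)"
        )
    # Case 2: sizes match -> plain reshape
    if needed == flat.size:
        return flat.reshape(new_shape)
    # Case 3: target too large -> pad the tail (assumed placement) with fill_value.
    # Promote the dtype so the fill value fits, e.g. int data padded with NaN becomes float.
    dtype = np.result_type(flat, fill_value)
    padding = np.full(needed - flat.size, fill_value, dtype=dtype)
    return np.concatenate([flat.astype(dtype), padding]).reshape(new_shape)


x = np.arange(6)
padded = safe_reshape(x, (2, 4))
print(padded)  # row 1 holds 4.0, 5.0 and then two nan entries
```

Promoting the dtype is a design choice of this sketch. Casting the NaN default into an integer array would otherwise fail.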