spaces.LogGridAxis

spaces.LogGridAxis(low, high, n, shape=None)

Axis for systematic logarithmic grid sampling over parameter bounds.

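Generates n values evenly spaced on a logarithmic scale from low to high, with both endpoints included, so consecutive values differ by a constant ratio. Sampling is deterministic: the same arguments always produce the same grid. This suits parameters that span several orders of magnitude, such as learning rates or regularization coefficients.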
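Written as a formula, "evenly spaced on a log scale with both endpoints included" gives the values below. This reading assumes n ≥ 2 (the n = 1 case is not specified here), and the worked numbers match the Examples section.

```latex
v_i = \mathrm{low}\cdot\left(\frac{\mathrm{high}}{\mathrm{low}}\right)^{i/(n-1)}, \qquad i = 0, 1, \ldots, n-1

% Worked example: low = 0.001, high = 1.0, n = 5
r = \left(\frac{1.0}{0.001}\right)^{1/4} = 1000^{1/4} \approx 5.623,
\qquad v \approx (0.001,\ 0.00562,\ 0.0316,\ 0.178,\ 1.0)
```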

Parameters

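| Name  | Type                    | Description                                                                                   | Default  |
|-------|-------------------------|-----------------------------------------------------------------------------------------------|----------|
| low   | float                   | Lower bound of the grid (inclusive). Must be positive.                                        | required |
| high  | float                   | Upper bound of the grid (inclusive). Must be positive and greater than low.                   | required |
| n     | int                     | Number of grid points to generate. Must be positive.                                          | required |
| shape | tuple[int, ...] or None | Extra trailing dimensions for the generated values; each grid value is repeated across them.  | None     |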

Raises

| Type       | Description                                                     |
|------------|-----------------------------------------------------------------|
| ValueError | If n <= 0, if low >= high, or if low or high is not positive.   |
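The documented error conditions can be sketched as a standalone check. The helper name `validate_log_grid_args` is hypothetical and not part of the library, and the order in which the checks run is an assumption.

```python
def validate_log_grid_args(low: float, high: float, n: int) -> None:
    """Hypothetical helper that raises ValueError under the documented conditions."""
    # The grid needs at least one point.
    if n <= 0:
        raise ValueError(f"n must be positive, got {n}")
    # Logarithms of zero or negative bounds are undefined.
    if low <= 0 or high <= 0:
        raise ValueError(f"low and high must be positive, got low={low}, high={high}")
    # The bounds must describe an increasing range.
    if low >= high:
        raise ValueError(f"low must be less than high, got low={low}, high={high}")


validate_log_grid_args(0.001, 1.0, 5)  # valid arguments: returns None
for bad_args in [(0.001, 1.0, 0), (1.0, 0.001, 5), (-1.0, 1.0, 5)]:
    try:
        validate_log_grid_args(*bad_args)
    except ValueError as err:
        print("rejected:", err)
```

Checking positivity before the ordering test means a negative bound is reported as such, rather than as a misordered range.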

Examples

>>> log_grid = LogGridAxis(0.001, 1.0, 5)
>>> values = log_grid.generate_values()
>>> print(values)  # [0.001, 0.00562, 0.0316, 0.178, 1.0] (approximately)
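The example values can be reproduced with `jax.numpy.logspace`, which spaces the base-10 exponents linearly and includes both endpoints by default. This is a sketch of equivalent arithmetic, not necessarily how `LogGridAxis` computes its grid; `log_grid` is a hypothetical name.

```python
import jax.numpy as jnp


def log_grid(low: float, high: float, n: int) -> jnp.ndarray:
    """Hypothetical helper: n points from low to high, evenly spaced in log10."""
    # Exponents run linearly from log10(low) to log10(high); logspace maps them back.
    return jnp.logspace(jnp.log10(low), jnp.log10(high), num=n)


values = log_grid(0.001, 1.0, 5)
# Neighbouring values share a constant ratio of (high / low) ** (1 / (n - 1)).
ratios = values[1:] / values[:-1]
print(values)
print(ratios)
```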

Attributes

| Name | Description           |
|------|-----------------------|
| size | Number of grid points. |

Methods

| Name            | Description                                  |
|-----------------|----------------------------------------------|
| generate_values | Generate logarithmically spaced grid values. |

generate_values

spaces.LogGridAxis.generate_values(key=None)

Generate logarithmically spaced grid values.

Parameters

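| Name | Type               | Description                                                   | Default |
|------|--------------------|---------------------------------------------------------------|---------|
| key  | jax.random.PRNGKey | Random key. Ignored, because grid sampling is deterministic.  | None    |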
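Because the grid is deterministic, any key (or none) should yield the same array. The class below is a hypothetical minimal stand-in, not the library's `LogGridAxis`; it only illustrates a `generate_values(key=None)` method that accepts and discards the key.

```python
import jax
import jax.numpy as jnp


class LogGridAxisSketch:
    """Hypothetical stand-in used only to illustrate the ignored key."""

    def __init__(self, low: float, high: float, n: int):
        self.low, self.high, self.size = low, high, n

    def generate_values(self, key=None) -> jnp.ndarray:
        del key  # deterministic grid: the key is accepted but never used
        return jnp.logspace(jnp.log10(self.low), jnp.log10(self.high), num=self.size)


axis = LogGridAxisSketch(0.001, 1.0, 5)
a = axis.generate_values(jax.random.PRNGKey(0))
b = axis.generate_values(jax.random.PRNGKey(1))
c = axis.generate_values()
print(bool(jnp.array_equal(a, b)), bool(jnp.array_equal(a, c)))
```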

Returns

| Type        | Description |
|-------------|-------------|
| jnp.ndarray | Logarithmically spaced values from low to high, both inclusive. If shape is None, the result has shape (n,). If shape is given, the result has shape (n,) + shape, and each grid value is repeated across the trailing dimensions. |
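A sketch of the (n,) + shape layout, assuming the grid runs along axis 0 as the shape suggests. A plain `jnp.broadcast_to(values, (n,) + shape)` would not work, because NumPy-style broadcasting aligns trailing axes, so the grid is first given one singleton axis per extra dimension. `broadcast_grid` is a hypothetical helper name.

```python
import jax.numpy as jnp


def broadcast_grid(values: jnp.ndarray, shape: tuple[int, ...]) -> jnp.ndarray:
    """Hypothetical helper: repeat each grid value across trailing dimensions."""
    # (n,) -> (n, 1, ..., 1): one singleton axis per extra dimension.
    expanded = values.reshape(values.shape + (1,) * len(shape))
    # Broadcast the singleton axes out to the requested trailing shape.
    return jnp.broadcast_to(expanded, values.shape + tuple(shape))


values = jnp.logspace(-3.0, 0.0, num=5)
grid = broadcast_grid(values, (2, 3))
print(grid.shape)
print(grid[:, 1, 2])  # every trailing slice holds the full grid
```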