probly.data_generation.jax_first_order_generator

JAX FirstOrder data generator.

Functions

output_dataloader(base_dataset, distributions, *)

Create a JAX-native loader yielding JAX arrays for distributions.

Classes

FirstOrderDataGenerator(model[, device, ...])

JAX-native FirstOrder data generator.

FirstOrderDataset(base_dataset, distributions)

Subclass the Python-first dataset, converting distributions to jnp arrays.

JAXOutputDataLoader(dataset, *[, ...])

JAX-native output loader yielding batches with JAX arrays.

class probly.data_generation.jax_first_order_generator.FirstOrderDataGenerator(model, device='cpu', batch_size=64, output_mode='auto', output_transform=None, input_getter=None, model_name=None, return_numpy=True, return_jax=True)[source]

Bases: FirstOrderDataGenerator

JAX-native FirstOrder data generator.

Parameters:
  • model (Callable[..., Any]) – Callable that maps a batch of inputs to logits or probs. Typically a JAX-transformed function that accepts jnp.ndarray inputs and returns jnp.ndarray outputs.

  • device (str) – Target device platform (e.g., cpu, gpu, tpu). Default cpu.

  • batch_size (int) – Batch size to use when wrapping Dataset.

  • output_mode (str) – One of {auto, logits, probs}. If auto, attempt to detect whether outputs are logits or probabilities. If logits, apply softmax. If probs, use the outputs as is.

  • output_transform (Callable[[object], jnp.ndarray] | None) – Function that converts raw model output to probabilities. Note: overrides output_mode when provided.

  • input_getter (Callable[[Any], Any] | None) – Function to extract model input from a dataset item. Signature: input_getter(sample) -> model_input. When None, dataset items are expected to be (input, target) or input only.

  • model_name (str | None) – Optional string identifier (saved with metadata).

  • return_numpy (bool)

  • return_jax (bool)
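
The interaction between output_mode and output_transform described above can be sketched as follows. This is a minimal illustration, not the library's implementation: the helper name to_probs_sketch is hypothetical, and the auto-detection heuristic (non-negative rows that sum to one are treated as probabilities) is an assumption.

```python
import jax
import jax.numpy as jnp


def to_probs_sketch(outputs, output_mode="auto", output_transform=None):
    """Hypothetical helper mirroring the documented parameter semantics."""
    # A user-supplied transform takes precedence over output_mode.
    if output_transform is not None:
        return output_transform(outputs)

    outputs = jnp.asarray(outputs, dtype=jnp.float32)
    if output_mode == "probs":
        return outputs  # used as is
    if output_mode == "logits":
        return jax.nn.softmax(outputs, axis=-1)
    if output_mode == "auto":
        # Assumed heuristic: rows that are non-negative and sum to 1
        # already look like probabilities.
        row_sums = jnp.sum(outputs, axis=-1)
        looks_like_probs = bool(
            jnp.all(outputs >= 0.0) & jnp.allclose(row_sums, 1.0, atol=1e-4)
        )
        return outputs if looks_like_probs else jax.nn.softmax(outputs, axis=-1)
    raise ValueError(f"unknown output_mode: {output_mode!r}")


logits = jnp.array([[2.0, 0.0, -1.0]])
probs = to_probs_sketch(logits, output_mode="logits")
```

Under this assumed heuristic, passing probs back in with output_mode="auto" leaves it unchanged, while raw logits are sent through softmax.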

generate_distributions(dataset_or_loader, *, progress=True)[source]

Generate per-sample probability distributions for a dataset or loader.

Returns a dict mapping dataset indices to either jnp.ndarray rows (when return_jax is True) or lists of floats.

Parameters:
Return type:

object
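
To make the shape of the returned mapping concrete, here is a small pure-Python sketch. It is not the library's code: generate_distributions_sketch and toy_model are hypothetical names, the dataset is assumed to hold (input, target) tuples, and the model is assumed to return logits.

```python
import math


def softmax(row):
    # Numerically stable softmax over one row of logits.
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def generate_distributions_sketch(model, dataset, batch_size=2):
    """Hypothetical stand-in: map each dataset index to one probability row."""
    distributions = {}
    for start in range(0, len(dataset), batch_size):
        batch = dataset[start:start + batch_size]
        inputs = [sample[0] for sample in batch]  # assumes (input, target) items
        for offset, logits in enumerate(model(inputs)):
            distributions[start + offset] = softmax(logits)
    return distributions


def toy_model(inputs):
    # Toy model returning two logits per input.
    return [[x, -x] for x in inputs]


dataset = [(0.0, 0), (1.0, 1), (2.0, 1)]
dists = generate_distributions_sketch(toy_model, dataset)
```

Each key in dists is a dataset index and each value is a list of floats that sums to one, matching the list-of-floats form described above.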

load_distributions(path)[source]

Load distributions and convert to JAX arrays.

Returns:

(distributions, meta), where:

  • distributions: dict[int, jnp.ndarray]

  • meta: dict with any metadata saved alongside distributions

Parameters:

path (str | Path)

Return type:

tuple[dict[int, jnp.ndarray], dict[str, Any]]

prepares_batch_inp(sample)[source]

Extract the model input from a dataset sample.

Behavior:

  • If input_getter is provided, use it to obtain the input.

  • If the sample is a tuple like (input, label, …), return the first element.

  • Otherwise, return the sample as-is.

Notes:

  • Lists are treated as input-only feature vectors and are NOT unpacked.

Parameters:

sample (object)

Return type:

object
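
The extraction rules for prepares_batch_inp can be restated as a short sketch. The function name extract_input is hypothetical and the body is only an assumption about how such rules might be written, not the method's actual source.

```python
def extract_input(sample, input_getter=None):
    """Hypothetical re-statement of the documented extraction rules."""
    # Rule 1: an explicit getter always wins.
    if input_getter is not None:
        return input_getter(sample)
    # Rule 2: tuples are treated as (input, label, ...) records.
    if isinstance(sample, tuple) and len(sample) > 0:
        return sample[0]
    # Rule 3: everything else, including lists, is returned as-is.
    return sample


from_tuple = extract_input(([1.0, 2.0], 3))
from_list = extract_input([1.0, 2.0])
from_getter = extract_input({"x": 5, "y": 1}, input_getter=lambda s: s["x"])
```

Here from_tuple and from_list are both the feature vector [1.0, 2.0]: the tuple is unpacked, while the list is passed through untouched.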

save_distributions(path, distributions, *, meta=None)[source]

Save distributions and optional metadata as JSON.

Parameters:
  • path (str | Path)

  • distributions (Mapping[int, Iterable[float]])

  • meta (dict[str, Any] | None)

Return type:

None
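
A JSON round trip of the kind performed by save_distributions and load_distributions can be sketched with the standard library. The file layout (top-level "distributions" and "meta" keys) and the function names save_sketch and load_sketch are assumptions made for illustration; the actual on-disk format may differ.

```python
import json
import tempfile
from pathlib import Path


def save_sketch(path, distributions, meta=None):
    # JSON object keys must be strings, so integer indices are stringified.
    payload = {
        "meta": meta or {},
        "distributions": {
            str(i): [float(p) for p in row] for i, row in distributions.items()
        },
    }
    Path(path).write_text(json.dumps(payload), encoding="utf-8")


def load_sketch(path):
    payload = json.loads(Path(path).read_text(encoding="utf-8"))
    # Restore integer indices that JSON stored as strings.
    distributions = {int(i): row for i, row in payload["distributions"].items()}
    return distributions, payload.get("meta", {})


with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "dists.json"
    save_sketch(target, {0: [0.25, 0.75], 1: [0.5, 0.5]}, meta={"model_name": "toy"})
    loaded, meta = load_sketch(target)
```

The key point is the index conversion: without the int(i) step on load, lookups by integer dataset index would fail. A JAX-side loader could additionally wrap each row with jnp.asarray.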

to_device(x)[source]

Move arrays, lists, or dicts to the configured JAX device, if available.

If no matching device is found, returns the input unchanged.

Parameters:

x (object)

Return type:

object
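
The fallback behavior described for to_device can be approximated with public JAX calls. This is a sketch under assumptions: to_device_sketch is a hypothetical name, and picking the first device of the requested platform is an illustrative choice rather than the documented one.

```python
import jax
import jax.numpy as jnp


def to_device_sketch(x, platform="cpu"):
    """Hypothetical helper: place x on the first device of `platform`."""
    try:
        devices = jax.devices(platform)
    except RuntimeError:
        # Platform not available in this process: leave x untouched.
        return x
    # device_put accepts pytrees, so dicts and lists are handled leaf by leaf.
    return jax.device_put(x, devices[0])


batch = {"inputs": jnp.ones((2, 3)), "targets": [0, 1]}
on_cpu = to_device_sketch(batch, "cpu")
unchanged = to_device_sketch(batch, "no_such_platform")
```

When the platform string does not match any available backend, the very same object is returned, which mirrors the "returns the input unchanged" rule above.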

to_probs(outputs)[source]

Convert model outputs to probabilities as jnp.ndarray or lists.

Parameters:

outputs (object)

Return type:

object

batch_size: int = 64
device: str = 'cpu'
input_getter: Callable[[Any], Any] | None = None
model: Callable[..., Any]
model_name: str | None = None
output_mode: str = 'auto'
output_transform: Callable[[object], jnp.ndarray] | None = None
return_jax: bool = True
return_numpy: bool = True
class probly.data_generation.jax_first_order_generator.FirstOrderDataset(base_dataset, distributions, input_getter=None, return_numpy=True)[source]

Bases: FirstOrderDataset

Subclass the Python-first dataset, converting distributions to jnp arrays.

Parameters:
class probly.data_generation.jax_first_order_generator.JAXOutputDataLoader(dataset, *, batch_size=64, shuffle=False, seed=None, device=None)[source]

Bases: object

JAX-native output loader yielding batches with JAX arrays.

Yields per-batch tuples of (inputs, distributions) or (inputs, labels, distributions), depending on whether the dataset has labels. Distributions are stacked as a jnp.ndarray with shape [batch, classes]. Inputs are converted to jnp.ndarray on a best-effort basis; if conversion is not possible, the original Python sequence is returned.

Shuffling uses jax.random.permutation. Optional device placement moves any JAX arrays to the selected device via jax.device_put.
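
The batching and shuffling described above can be illustrated with a small generator. It is a simplified sketch, not the class itself: iter_batches is a hypothetical name, labels and device placement are omitted, and seeding through jax.random.PRNGKey(seed) is an assumption about how the seed might be used.

```python
import jax
import jax.numpy as jnp


def iter_batches(inputs, distributions, batch_size=2, shuffle=False, seed=0):
    """Hypothetical generator yielding (inputs, distributions) batches."""
    n = len(inputs)
    if shuffle:
        # A fixed seed gives a reproducible permutation of the indices.
        order = jax.random.permutation(jax.random.PRNGKey(seed), n)
    else:
        order = jnp.arange(n)
    for start in range(0, n, batch_size):
        idx = [int(i) for i in order[start:start + batch_size]]
        batch_inputs = jnp.asarray([inputs[i] for i in idx])
        # Stack per-sample rows into shape [batch, classes].
        batch_dists = jnp.stack([jnp.asarray(distributions[i]) for i in idx])
        yield batch_inputs, batch_dists


inputs = [[0.0, 1.0], [1.0, 0.0], [2.0, 2.0]]
distributions = {0: [0.9, 0.1], 1: [0.2, 0.8], 2: [0.5, 0.5]}
batches = list(iter_batches(inputs, distributions, batch_size=2, shuffle=True, seed=42))
```

With three samples and batch_size=2, the generator yields one full batch and one partial batch; each distributions array keeps the [batch, classes] layout.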

Parameters:
probly.data_generation.jax_first_order_generator.output_dataloader(base_dataset, distributions, *, batch_size=64, shuffle=False, num_workers=0, pin_memory=False, input_getter=None, seed=None, device=None)[source]

Create a JAX-native loader yielding JAX arrays for distributions.

The num_workers and pin_memory parameters are kept for API parity but are ignored. Use seed to control the shuffle permutation and device to place arrays (e.g., “cpu”, “gpu”, “tpu”, or a more specific device such as “gpu:0”).

Parameters:
  • base_dataset (DatasetLike)

  • distributions (Mapping[int, Iterable[float]])

  • batch_size (int)

  • shuffle (bool)

  • num_workers (int)

  • pin_memory (bool)

  • input_getter (Callable[[Any], Any] | None)

  • seed (int | None)

  • device (str | None)

Return type:

JAXOutputDataLoader