probly.data_generation.jax_generator

JAX data generator implementation.

Runs a JAX model over input arrays, collects simple statistics, and provides helpers to persist results.

Classes

JAXDataGenerator(model, dataset[, ...])

Data generator for JAX models.

class probly.data_generation.jax_generator.JAXDataGenerator(model, dataset, batch_size=32, device=None)

Bases: BaseDataGenerator[Callable[[Array], Array], tuple[object, object], str | None]

Data generator for JAX models.

Parameters:
  • model (Callable[[jnp.ndarray], jnp.ndarray])

  • dataset (tuple[object, object])

  • batch_size (int), default 32

  • device (str | None), default None
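The page does not say how the device string is interpreted. As a rough illustration only, the sketch below shows one way a `str | None` value could be mapped to a JAX device with the public `jax.devices` and `jax.device_put` calls. The helper name `resolve_device` is hypothetical and is not part of probly.

```python
import jax
import jax.numpy as jnp


def resolve_device(name=None):
    """Hypothetical helper: map a backend name such as "cpu" to a JAX device.

    None falls back to JAX's default backend. This is an assumption about how
    a `device` argument might be handled, not the library's documented behavior.
    """
    devices = jax.devices(name) if name is not None else jax.devices()
    return devices[0]


device = resolve_device("cpu")

# Place an input array on the chosen device before running a model on it.
inputs = jax.device_put(jnp.ones((4, 3)), device)
```

Passing `None` here simply picks the first device of the default backend, which mirrors the `device=None` default in the signature above.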

generate()

Run the model on the dataset and compute basic metrics.

Return type:

dict[str, Any]
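The exact metrics and dictionary keys returned by generate() are not listed on this page. The following is a minimal sketch, under stated assumptions, of what "run the model on the dataset and compute basic metrics" could look like for a classifier: the model is applied in slices of `batch_size`, and accuracy and mean confidence are computed from the outputs. The function names and the keys `accuracy`, `mean_confidence`, and `num_samples` are illustrative, not the library's.

```python
import jax
import jax.numpy as jnp


def run_in_batches(model, inputs, batch_size=32):
    """Apply `model` to slices of `inputs` and concatenate the outputs."""
    outputs = []
    for start in range(0, inputs.shape[0], batch_size):
        batch = inputs[start:start + batch_size]
        outputs.append(model(batch))
    return jnp.concatenate(outputs, axis=0)


def basic_metrics(logits, labels):
    """Hypothetical metrics: accuracy and mean max-softmax confidence."""
    probs = jax.nn.softmax(logits, axis=-1)
    preds = jnp.argmax(probs, axis=-1)
    return {
        "accuracy": float(jnp.mean(preds == labels)),
        "mean_confidence": float(jnp.mean(jnp.max(probs, axis=-1))),
        "num_samples": int(labels.shape[0]),
    }


# Toy linear model: a plain callable from array to array,
# matching the Callable[[Array], Array] type in the class signature.
weights = jnp.array([[2.0, -2.0], [-2.0, 2.0]])


def model(x):
    return x @ weights


# The dataset is an (inputs, labels) pair; the last label is deliberately wrong.
inputs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [1.0, 0.0]])
labels = jnp.array([0, 1, 0, 1, 1])

logits = run_in_batches(model, inputs, batch_size=2)
metrics = basic_metrics(logits, labels)
```

With a batch size of 2, the five inputs are processed as batches of 2, 2, and 1, so the final partial batch is handled without padding.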

get_info()

Return a summary of the generator's configuration.

Return type:

dict[str, Any]

load(path)

Load results from a file.

Parameters:

path (str)

Return type:

dict[str, Any]

save(path)

Save generated results to a file.

Parameters:

path (str)

Return type:

None
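The on-disk format used by save() and load() is not stated here. As a sketch only, assuming the results are a JSON-serializable dictionary, a save and load round trip could be written as below. The function names `save_results` and `load_results` and the sample `results` dictionary are hypothetical stand-ins, not the library's implementation.

```python
import json
import os
import tempfile


def save_results(results, path):
    """Write a results dictionary to `path` as JSON (assumed format)."""
    with open(path, "w", encoding="utf-8") as fh:
        json.dump(results, fh, indent=2)


def load_results(path):
    """Read a results dictionary back from a JSON file."""
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


# Illustrative payload; the real structure returned by generate() may differ.
results = {
    "metrics": {"accuracy": 0.8, "num_samples": 5},
    "info": {"batch_size": 32, "device": None},
}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "results.json")
    save_results(results, path)
    restored = load_results(path)
```

JAX arrays are not JSON-serializable, so under this assumption any array-valued metric would need to be converted to Python floats or lists before saving.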