probly.data_generation.jax_generator
JAX data generator implementation.
Runs a JAX model over input arrays, collects simple statistics, and provides helpers to persist results.
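The workflow described above (apply a model to input arrays batch by batch, summarize the outputs, write the summary to disk) can be sketched in plain JAX. This is a minimal illustration under stated assumptions, not probly's implementation. The helpers `run_in_batches`, `simple_stats` and `save_stats`, and the choice of per-column mean and standard deviation as the "simple statistics", are hypothetical. Only the default `batch_size=32` mirrors the documented constructor.

```python
import json

import jax
import jax.numpy as jnp


# Hypothetical helper (not part of probly): apply `model` to `inputs` in
# slices of `batch_size` rows and concatenate the results along axis 0.
def run_in_batches(model, inputs, batch_size=32):
    outputs = []
    for start in range(0, inputs.shape[0], batch_size):
        batch = inputs[start:start + batch_size]
        outputs.append(model(batch))
    return jnp.concatenate(outputs, axis=0)


# Hypothetical "simple statistics": sample count plus per-column mean and
# standard deviation, converted to plain Python types so they serialize.
def simple_stats(outputs):
    return {
        "num_samples": int(outputs.shape[0]),
        "mean": jnp.mean(outputs, axis=0).tolist(),
        "std": jnp.std(outputs, axis=0).tolist(),
    }


# Hypothetical persistence helper: write the statistics dictionary as JSON.
def save_stats(stats, path):
    with open(path, "w", encoding="utf-8") as f:
        json.dump(stats, f, indent=2)


# Usage: a jitted linear model of type Callable[[Array], Array].
w = jnp.array([[1.0], [2.0]])
model = jax.jit(lambda x: x @ w)
x = jnp.ones((100, 2))

out = run_in_batches(model, x, batch_size=32)  # batches of 32, 32, 32, 4
stats = simple_stats(out)
```

Because the last batch is smaller (4 rows), `jax.jit` traces the model a second time for that shape. A real generator might pad the final batch to avoid recompilation. Device placement (the `device` argument in the documented signature) is deliberately left out of this sketch.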
Classes
JAXDataGenerator | Data generator for JAX models.
- class probly.data_generation.jax_generator.JAXDataGenerator(model, dataset, batch_size=32, device=None)
Bases: BaseDataGenerator[Callable[[Array], Array], tuple[object, object], str | None]

Data generator for JAX models.
- Parameters: