probly.data_generation.factory
Factory for creating framework-specific data generators.
Selects an appropriate BaseDataGenerator implementation based on the framework argument (“pytorch”, “tensorflow”, or “jax”).
Functions
create_data_generator | Create a data generator based on the selected framework.
- probly.data_generation.factory.create_data_generator(framework: Literal['pytorch'], model: torch.nn.Module, dataset: TorchDataset[Any], batch_size: int = 32, device: str | None = None) → PyTorchDataGenerator
- probly.data_generation.factory.create_data_generator(framework: Literal['jax'], model: Callable[[Any], Any], dataset: tuple[Any, Any], batch_size: int = 32, device: str | None = None) → JAXDataGenerator
Create a data generator based on the selected framework.