probly.data_generation.pytorch_generator

PyTorch data generator implementation.

Runs a PyTorch model over a dataset, collects simple statistics, and provides helpers to persist results.

Classes

PyTorchDataGenerator(model, dataset[, ...])

Data generator for PyTorch models.

class probly.data_generation.pytorch_generator.PyTorchDataGenerator(model, dataset, batch_size=32, device=None, num_workers=0)

Bases: BaseDataGenerator[Module, Dataset, str | None]

Data generator for PyTorch models.

Parameters:

model (Module)

dataset (Dataset)

batch_size (int)

device (str | None)

num_workers (int)

generate()

Run the model on the dataset and compute basic metrics.

Return type:

dict[str, Any]
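The page does not say which metrics generate() computes. The following is a minimal sketch of such an evaluation loop, assuming the dataset yields (input, target) pairs and the model is a classifier that returns logits. The function name run_model_sketch and the result keys are hypothetical and are not part of probly's API.

```python
from __future__ import annotations

from typing import Any

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset


def run_model_sketch(
    model: nn.Module,
    dataset: Dataset,
    batch_size: int = 32,
    device: str | None = None,
    num_workers: int = 0,
) -> dict[str, Any]:
    """Hypothetical stand-in for generate(); assumes (inputs, targets) pairs and a classifier."""
    # Assumed default: use CUDA when available, otherwise CPU.
    dev = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
    model = model.to(dev)
    model.eval()  # disable dropout, use running batch-norm statistics
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    all_logits, all_targets = [], []
    with torch.no_grad():  # inference only, no gradient tracking
        for inputs, targets in loader:
            logits = model(inputs.to(dev))
            all_logits.append(logits.cpu())
            all_targets.append(targets)

    logits = torch.cat(all_logits)
    targets = torch.cat(all_targets)
    preds = logits.argmax(dim=1)
    # Hypothetical "basic metrics": sample count, accuracy, mean top-class probability.
    return {
        "num_samples": int(targets.numel()),
        "accuracy": float((preds == targets).float().mean()),
        "mean_max_prob": float(logits.softmax(dim=1).max(dim=1).values.mean()),
    }
```

For example, run_model_sketch(nn.Linear(4, 3), TensorDataset(x, y), device="cpu") returns a dict of plain Python numbers. Plain numbers are convenient if the results are later written to a file.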

get_info()

Return a summary of the generator's configuration.

Return type:

dict[str, Any]
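The keys returned by get_info() are not documented here. One plausible shape mirrors the constructor arguments. The sketch below assumes that shape. The function name get_info_sketch and all keys are hypothetical.

```python
from __future__ import annotations

from collections.abc import Sized
from typing import Any


def get_info_sketch(
    model: object,
    dataset: Sized,
    batch_size: int = 32,
    device: str | None = None,
    num_workers: int = 0,
) -> dict[str, Any]:
    """Hypothetical config summary; keys are illustrative, not probly's actual output."""
    return {
        "model": type(model).__name__,  # class name instead of the full module repr
        "num_samples": len(dataset),    # requires a dataset that implements __len__
        "batch_size": batch_size,
        "device": device,
        "num_workers": num_workers,
    }
```

Keeping the summary to strings and numbers lets it be logged or stored alongside generated results without extra conversion.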

load(path)

Load results from a file.

Parameters:

path (str)

Return type:

dict[str, Any]

save(path)

Save generated results to a file.

Parameters:

path (str)

Return type:

None
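The file format that save() and load() use is not stated on this page. The sketch below assumes JSON and shows a save/load round trip. The functions save_results and load_results are hypothetical. Under this assumption, tensors in the results would first need converting to Python floats or lists.

```python
from __future__ import annotations

import json
from pathlib import Path
from typing import Any


def save_results(results: dict[str, Any], path: str) -> None:
    """Hypothetical save(): write the results dict as JSON (format is an assumption)."""
    Path(path).write_text(json.dumps(results, indent=2), encoding="utf-8")


def load_results(path: str) -> dict[str, Any]:
    """Hypothetical load(): read back a dict written by save_results()."""
    data = json.loads(Path(path).read_text(encoding="utf-8"))
    if not isinstance(data, dict):
        raise ValueError(f"expected a JSON object in {path}, got {type(data).__name__}")
    return data
```

Pairing the two functions this way means load_results(p) returns a dict equal to the one passed to save_results(results, p), provided every value is JSON-serializable.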