probly.data_generation.torch_first_order_generator
Torch FirstOrder data generator.
Functions
- load_distributions_pt: Load distributions from a torch binary file (.pt / .pth).
- output_dataloader: Create a DataLoader pairing inputs (and labels, if available) with first-order distributions.
- Save distributions to a torch binary file (.pt / .pth).
- Move a tensor or nested tensors to the specified device.
Classes
- FirstOrderDataGenerator: Version First-Order data generator.
- FirstOrderDataset: Wrap an existing dataset (like base_dataset) with first-order distributions for training/eval.
- class probly.data_generation.torch_first_order_generator.FirstOrderDataGenerator(model, device='cpu', batch_size=64, output_mode='auto', output_transform=None, input_getter=None, model_name=None, return_numpy=True, return_torch=True)
Bases: FirstOrderDataGenerator
Version First-Order data generator.
- Parameters:
model (torch.nn.Module | Callable[..., Any]) – A Callable that maps a batch of inputs to logits or probs. Normally a torch.nn.Module.
device (str) – Device for inference (e.g., ‘cpu’ or ‘cuda’). Default ‘cpu’.
batch_size (int) – Batch size to use when wrapping a Dataset. Default 64 (previously 128).
output_mode (str) – One of {‘auto’, ‘logits’, ‘probs’}. If ‘auto’, attempt to detect whether outputs are logits or probabilities. If ‘logits’, apply softmax. If ‘probs’, use as is. Default ‘auto’.
output_transform (Callable[[object], torch.Tensor] | None) – Function that converts raw model output to probabilities. If provided, it takes precedence over output_mode.
input_getter (Callable[[Any], Any] | None) – Function that extracts the model input from a dataset item. Signature: input_getter(sample) -> model_input. When None, dataset items are expected to be (input, target) or input only.
model_name (str | None) – Optional string identifier, saved with the metadata.
return_numpy (bool)
return_torch (bool)
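The three output_mode values can be illustrated with a plain-Python sketch. The detection rule used for ‘auto’ here (non-negative entries that sum to roughly 1) is an assumption for illustration; the source does not state which heuristic the library applies, and the helper names softmax, looks_like_probs and to_probs are hypothetical.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def looks_like_probs(row, tol=1e-4):
    """Assumed heuristic: entries are non-negative and sum to about 1."""
    return all(x >= 0.0 for x in row) and abs(sum(row) - 1.0) <= tol


def to_probs(row, output_mode="auto"):
    """Convert one row of model output to probabilities (hypothetical helper)."""
    if output_mode == "probs":
        return list(row)      # use as is
    if output_mode == "logits":
        return softmax(row)   # always apply softmax
    if output_mode == "auto":
        # Guess: keep rows that already look like a distribution.
        return list(row) if looks_like_probs(row) else softmax(row)
    raise ValueError(f"unknown output_mode: {output_mode!r}")
```

A heuristic like this can misclassify logits that happen to be non-negative and sum to 1, which is one reason to pass ‘logits’ or ‘probs’ explicitly when the model's output type is known.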
- generate_distributions(dataset_or_loader, *, progress=True)
Generate per-sample probability distributions.
- Parameters:
dataset_or_loader (object) – A torch.utils.data.Dataset or torch.utils.data.DataLoader. Items should be tensors, or tuples/dicts that contain tensors.
progress (bool) – If True, print simple progress information to the terminal.
- Returns:
Mapping from dataset index to list of probabilities.
- Return type:
dict[int, list[float]]
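The shape of the returned mapping can be sketched without torch. The function below is an illustrative stand-in, not the library's implementation: it assumes the data is visited in order (no shuffling), that the model already returns one probability row per input, and it uses a hypothetical toy_model in place of a real network.

```python
def generate_distributions_sketch(dataset, model, batch_size=64):
    """Return {dataset index: list of probabilities} (illustrative only)."""
    distributions = {}
    for start in range(0, len(dataset), batch_size):
        batch = dataset[start:start + batch_size]
        outputs = model(batch)  # assumed: one probability row per input
        for offset, row in enumerate(outputs):
            # The key is the position of the sample in the dataset.
            distributions[start + offset] = [float(p) for p in row]
    return distributions


def toy_model(batch):
    """Hypothetical two-class model for non-negative numeric inputs."""
    return [[x / (x + 1), 1 / (x + 1)] for x in batch]


dists = generate_distributions_sketch([0, 1, 3], toy_model, batch_size=2)
```

Keying by dataset index only stays meaningful if the order of iteration matches the order of the dataset, which is why the sketch assumes an unshuffled pass.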
- get_posterior_distributions()
Extracts u and p from all BayesLinear layers (issue #241).
Returns a dict compatible with future torch.save/load.
- load_distributions(path)
Load distributions from JSON and return Torch tensors.
- Returns:
(distributions, meta), where distributions is a dict[int, torch.Tensor] and meta is a dict with any metadata saved alongside the distributions.
- prepares_batch_inp(sample)
Extract the model input from a dataset sample or batch.
Behavior:
- If input_getter is provided, use it.
- If the sample/batch is a tuple or list like (inputs, labels, …), return the first element (inputs).
- Otherwise, return the sample as-is.
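The three rules above can be written as a small standalone function. This is a sketch of the documented behavior, not the method's actual source; the name prepare_batch_input is hypothetical, and the guard against empty sequences is an added assumption.

```python
def prepare_batch_input(sample, input_getter=None):
    """Pick the model input out of a dataset sample or batch (sketch)."""
    if input_getter is not None:
        # Rule 1: a user-supplied getter always wins.
        return input_getter(sample)
    if isinstance(sample, (tuple, list)) and len(sample) > 0:
        # Rule 2: (inputs, labels, ...) -> inputs.
        return sample[0]
    # Rule 3: the sample is already the input.
    return sample
```

For dict-style samples, a getter such as `lambda s: s["image"]` (the key name is only an example) covers the case that the tuple rule cannot.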
- save_distributions(path, distributions, *, meta=None)
Save distributions and optional metadata as JSON.
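One detail of a JSON round trip is worth seeing in code: JSON object keys are always strings, so integer dataset indices must be converted back on load. The sketch below uses only the standard library; the file layout (top-level "meta" and "distributions" keys) and the function names are assumptions, and it returns plain lists where the documented load_distributions returns Torch tensors.

```python
import json
import os
import tempfile


def save_distributions_json(path, distributions, meta=None):
    """Write {index: probabilities} plus optional metadata (assumed layout)."""
    payload = {
        "meta": meta or {},
        # JSON keys must be strings, so indices are stringified here.
        "distributions": {str(i): list(p) for i, p in distributions.items()},
    }
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f)


def load_distributions_json(path):
    """Read the file back and restore integer keys."""
    with open(path, encoding="utf-8") as f:
        payload = json.load(f)
    distributions = {int(i): p for i, p in payload["distributions"].items()}
    return distributions, payload.get("meta", {})


# Round trip in a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "dists.json")
    save_distributions_json(path, {0: [0.9, 0.1]}, meta={"model_name": "demo"})
    loaded, meta = load_distributions_json(path)
```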
- model: torch.nn.Module | Callable[..., Any]
- output_transform: Callable[[object], torch.Tensor] | None = None
- class probly.data_generation.torch_first_order_generator.FirstOrderDataset(base_dataset, distributions, input_getter=None)
Bases: Dataset
Wrap an existing dataset (like base_dataset) with first-order distributions for training/eval.
Returns items as (input, distribution) if the base dataset yields only input, or (input, label, distribution) if the base dataset yields (input, label).
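The two item shapes can be illustrated with a minimal map-style wrapper that follows the same `__len__`/`__getitem__` protocol as a torch Dataset. The class name PairedDataset is hypothetical, input_getter handling is left out, and the test for "yields (input, label)" (a tuple or list with at least two elements) is an assumption about how the real class decides.

```python
class PairedDataset:
    """Attach a stored distribution to each item of a base dataset (sketch)."""

    def __init__(self, base_dataset, distributions):
        self.base_dataset = base_dataset
        self.distributions = distributions  # {dataset index: probabilities}

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        sample = self.base_dataset[idx]
        dist = self.distributions[idx]
        if isinstance(sample, (tuple, list)) and len(sample) >= 2:
            # Base dataset yields (input, label): keep the label.
            return sample[0], sample[1], dist
        # Base dataset yields only the input.
        return sample, dist


labeled = PairedDataset([("a", 0), ("b", 1)], {0: [0.9, 0.1], 1: [0.2, 0.8]})
unlabeled = PairedDataset(["a"], {0: [1.0]})
```

With these toy inputs, `labeled[1]` is a three-element tuple and `unlabeled[0]` a two-element one, so a training loop has to unpack items according to which kind of base dataset it wrapped.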
- probly.data_generation.torch_first_order_generator.load_distributions_pt(load_path, *, device=None, verbose=True)
Load distributions from a torch binary file (.pt / .pth).
- probly.data_generation.torch_first_order_generator.output_dataloader(base_dataset, distributions, *, batch_size=64, shuffle=False, num_workers=0, pin_memory=False, input_getter=None)
Create a DataLoader pairing inputs (and labels, if available) with first-order distributions.