probly.utils.torch
Utility functions for PyTorch models.
Functions
| temperature_softmax(logits, temperature) | Compute the softmax of logits with temperature scaling applied. |
| torch_collect_outputs(model, loader, device) | Collect outputs and targets from a model for a given data loader. |
| | Reset all parameters of a torch module. |
- probly.utils.torch.temperature_softmax(logits, temperature)
Compute the softmax of logits with temperature scaling applied.
Divides logits by temperature and applies the softmax over the last dimension, which is assumed to be the class dimension. A temperature above 1 flattens the resulting distribution; a temperature below 1 sharpens it.
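The computation described above is softmax(z / T) over the last axis, i.e. p_i = exp(z_i / T) / Σ_j exp(z_j / T). The following is a minimal sketch of that computation, not probly's actual source. It assumes `temperature` is a positive scalar (or a tensor that broadcasts against `logits`); the name `temperature_softmax_sketch` is hypothetical.

```python
from __future__ import annotations

import torch


def temperature_softmax_sketch(logits: torch.Tensor, temperature: float | torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: softmax(logits / temperature) over the last dimension."""
    # Scale the logits by 1/T; T > 1 flattens the distribution, T < 1 sharpens it.
    scaled = logits / temperature
    # The last dimension is treated as the class dimension.
    return torch.softmax(scaled, dim=-1)


logits = torch.tensor([[2.0, 1.0, 0.0]])
print(temperature_softmax_sketch(logits, 1.0))   # identical to a plain softmax
print(temperature_softmax_sketch(logits, 10.0))  # noticeably flatter distribution
```

With T = 1 the result reduces to an ordinary softmax, which is why temperature scaling is often used as a post-hoc calibration step: it rescales confidence without changing the arg-max class.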
- probly.utils.torch.torch_collect_outputs(model, loader, device)
Collect outputs and targets from a model for a given data loader.
- Parameters:
model (Module) – torch.nn.Module, model to collect outputs from
loader (DataLoader) – torch.utils.data.DataLoader, data loader to collect outputs from
device (device) – torch.device, device to move data to
- Returns:
outputs (torch.Tensor) – shape (n_instances, n_classes), model outputs
targets (torch.Tensor) – shape (n_instances,), target labels
- Return type:
tuple[torch.Tensor, torch.Tensor]
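The collection loop can be sketched as follows. This is a hypothetical illustration, not probly's implementation. It assumes that the loader yields `(inputs, targets)` pairs, that the model should run in eval mode without gradient tracking, and that results are moved back to the CPU before concatenation. The name `collect_outputs_sketch` is made up for this example.

```python
from __future__ import annotations

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()  # no autograd graph is needed for pure inference
def collect_outputs_sketch(
    model: nn.Module, loader: DataLoader, device: torch.device
) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical sketch: run `model` over `loader` and stack outputs and targets."""
    model.eval()  # assumption: disables dropout and freezes batch-norm statistics
    model.to(device)
    outputs, targets = [], []
    for inputs, batch_targets in loader:  # assumption: loader yields (inputs, targets)
        inputs = inputs.to(device)
        outputs.append(model(inputs).cpu())  # assumption: results gathered on the CPU
        targets.append(batch_targets.cpu())
    # Concatenate per-batch results along the instance dimension.
    return torch.cat(outputs, dim=0), torch.cat(targets, dim=0)


torch.manual_seed(0)
dataset = TensorDataset(torch.randn(10, 4), torch.randint(0, 3, (10,)))
loader = DataLoader(dataset, batch_size=4)  # batches of 4, 4 and 2 instances
model = nn.Linear(4, 3)  # toy 3-class model
outputs, targets = collect_outputs_sketch(model, loader, torch.device("cpu"))
print(outputs.shape, targets.shape)
```

Concatenating along dimension 0 yields the documented shapes (n_instances, n_classes) and (n_instances,) even when the final batch is smaller than the others.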