Source code for probly.utils.torch
"""Utility functions for PyTorch models."""
from __future__ import annotations
import torch
import torch.nn.functional as F
from tqdm import tqdm
[docs]
@torch.no_grad()
def torch_collect_outputs(
    model: torch.nn.Module,
    loader: torch.utils.data.DataLoader,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Collect outputs and targets from a model for a given data loader.

    Every ``(input, target)`` batch yielded by ``loader`` is moved to
    ``device``, the input is passed through ``model`` (under
    ``torch.no_grad()``), and the per-batch results are concatenated along
    the first dimension. The model's train/eval mode is not changed here.

    Args:
        model: torch.nn.Module, model to collect outputs from
        loader: torch.utils.data.DataLoader, data loader yielding ``(input, target)`` batches
        device: torch.device, device to move data to

    Returns:
        outputs: torch.Tensor, shape (n_instances, n_classes), model outputs
        targets: torch.Tensor, shape (n_instances,), target labels, keeping the
            dtype yielded by the loader (e.g. integer class indices stay integer)

        If the loader yields no batches, both are empty (float) tensors on ``device``.
    """
    output_batches: list[torch.Tensor] = []
    target_batches: list[torch.Tensor] = []
    for inpt, target in tqdm(loader, desc="Batches"):
        output_batches.append(model(inpt.to(device)))
        target_batches.append(target.to(device))

    if not output_batches:
        # Empty loader: keep returning empty tensors rather than letting
        # torch.cat fail on an empty list.
        return torch.empty(0, device=device), torch.empty(0, device=device)

    # Concatenate once at the end. Growing a tensor with torch.cat inside the
    # loop re-copies every previously collected row on each iteration
    # (quadratic), and seeding with a float32 ``torch.empty(0)`` made torch.cat
    # type-promote integer labels (and low-precision outputs) to float32.
    return torch.cat(output_batches, dim=0), torch.cat(target_batches, dim=0)
[docs]
def torch_reset_all_parameters(module: torch.nn.Module) -> None:
    """Reset all parameters of a torch module.

    Calls ``reset_parameters()`` on ``module`` itself and on every submodule
    at any nesting depth that defines it. Modules without a
    ``reset_parameters`` attribute (e.g. plain containers such as
    ``nn.Sequential``) are skipped.

    Args:
        module: torch.nn.Module, module to reset parameters
    """
    # module.modules() yields ``module`` first, then all descendants
    # recursively. The previous children()-based loop only reached direct
    # children, so layers nested inside containers were never reset.
    for submodule in module.modules():
        if hasattr(submodule, "reset_parameters"):
            submodule.reset_parameters()  # type: ignore[operator]
[docs]
def temperature_softmax(logits: torch.Tensor, temperature: float | torch.Tensor) -> torch.Tensor:
"""Compute the softmax of logits with temperature scaling applied.
Computes the softmax based on the logits divided by the temperature. Assumes that the last dimension
of logits is the class dimension.
Args:
logits: torch.Tensor, shape (n_instances, n_classes), logits to apply softmax on
temperature: float, temperature scaling factor
Returns:
ts: torch.Tensor, shape (n_instances, n_classes), softmax of logits with temperature scaling applied
"""
ts = F.softmax(logits / temperature, dim=-1)
return ts