probly.utils.torch

Utility functions for PyTorch models.

Functions

temperature_softmax(logits, temperature)

Compute the softmax of logits with temperature scaling applied.

torch_collect_outputs(model, loader, device)

Collect outputs and targets from a model for a given data loader.

torch_reset_all_parameters(module)

Reset all parameters of a torch module.

probly.utils.torch.temperature_softmax(logits, temperature)

Compute the softmax of logits with temperature scaling applied.

Divides the logits by the temperature and applies the softmax over the last dimension of logits, which is assumed to be the class dimension.
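For a single instance with logits z = (z_1, ..., z_C) over C classes and temperature T > 0, this is the standard temperature-scaled softmax. T > 1 flattens the distribution toward uniform, T < 1 sharpens it toward the arg max, and T = 1 recovers the ordinary softmax.

```latex
\operatorname{softmax}_T(\mathbf{z})_k = \frac{\exp(z_k / T)}{\sum_{j=1}^{C} \exp(z_j / T)}, \qquad k = 1, \dots, C
```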

Parameters:
  • logits (Tensor) – torch.Tensor, shape (n_instances, n_classes), logits to apply the softmax to

  • temperature (float | Tensor) – float or torch.Tensor, temperature scaling factor by which the logits are divided

Returns:

torch.Tensor, shape (n_instances, n_classes), softmax of logits with temperature scaling applied

Return type:

Tensor
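The following is a minimal sketch of the computation described above, assuming it is equivalent to dividing the logits by the temperature and calling `torch.softmax` over the last dimension. The name `temperature_softmax_sketch` is hypothetical and this is not the probly source.

```python
from __future__ import annotations

import torch


def temperature_softmax_sketch(logits: torch.Tensor, temperature: float | torch.Tensor) -> torch.Tensor:
    """Hypothetical re-implementation: softmax(logits / temperature) over the class dim."""
    # Division broadcasts, so temperature may be a Python float or a tensor
    # broadcastable to logits (e.g. shape (n_instances, 1) for per-instance values).
    return torch.softmax(logits / temperature, dim=-1)


logits = torch.tensor([[2.0, 1.0, 0.0], [0.5, 0.5, 3.0]])
sharp = temperature_softmax_sketch(logits, 0.5)   # T < 1: more peaked
plain = temperature_softmax_sketch(logits, 1.0)   # T = 1: ordinary softmax
flat = temperature_softmax_sketch(logits, 10.0)   # T > 1: closer to uniform
per_row = temperature_softmax_sketch(logits, torch.tensor([[1.0], [2.0]]))
print(sharp[0].max().item(), plain[0].max().item(), flat[0].max().item())
```

Because each row still passes through a softmax, every output row sums to 1 regardless of T; only the spread of probability mass changes.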

probly.utils.torch.torch_collect_outputs(model, loader, device)

Collect outputs and targets from a model for a given data loader.

Parameters:
  • model (Module) – torch.nn.Module, model to collect outputs from

  • loader (DataLoader) – torch.utils.data.DataLoader, data loader to collect outputs from

  • device (device) – torch.device, device to move data to

Returns:
  • outputs – torch.Tensor, shape (n_instances, n_classes), model outputs

  • targets – torch.Tensor, shape (n_instances,), target labels

Return type:

tuple[Tensor, Tensor]
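A minimal sketch of how such a collection loop is commonly written, assuming the loader yields `(inputs, targets)` batches and that the model already lives on `device`. The name `collect_outputs_sketch`, the `model.eval()` call, and the use of `torch.no_grad()` are illustrative assumptions; the probly function may differ in these details.

```python
from __future__ import annotations

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()  # assumption: no gradients are needed when only collecting outputs
def collect_outputs_sketch(
    model: nn.Module, loader: DataLoader, device: torch.device
) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper: run model over loader, concatenate outputs and targets."""
    model.eval()  # assumption: evaluation mode (disables dropout, freezes BN stats)
    outputs, targets = [], []
    for inputs, target in loader:  # assumes each batch is an (inputs, targets) pair
        outputs.append(model(inputs.to(device)))  # move data to the given device
        targets.append(target.to(device))
    # Stack per-batch results along the instance dimension.
    return torch.cat(outputs), torch.cat(targets)


torch.manual_seed(0)
x = torch.randn(10, 4)
y = torch.randint(0, 3, (10,))
loader = DataLoader(TensorDataset(x, y), batch_size=4)  # batches of 4, 4, 2
model = nn.Linear(4, 3)
outputs, targets = collect_outputs_sketch(model, loader, torch.device("cpu"))
print(outputs.shape, targets.shape)
```

With an unshuffled loader, the concatenated targets come back in dataset order, so row i of outputs lines up with label i.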

probly.utils.torch.torch_reset_all_parameters(module)

Reset all parameters of a torch module.

Parameters:

module (Module) – torch.nn.Module, module whose parameters are reset

Return type:

None
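One common way to reset every parameter in a module tree is to walk all submodules and call `reset_parameters()` wherever a layer defines it. The sketch below assumes that approach; `reset_all_parameters_sketch` is a hypothetical name, and the probly function is not confirmed to work this way.

```python
import torch
from torch import nn


def reset_all_parameters_sketch(module: nn.Module) -> None:
    """Hypothetical helper: re-initialise every submodule that knows how to reset itself."""
    # module.modules() yields the module itself and all nested submodules.
    for submodule in module.modules():
        reset = getattr(submodule, "reset_parameters", None)
        if callable(reset):  # layers such as ReLU or Sequential define no reset
            reset()


net = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 2))
with torch.no_grad():
    for p in net.parameters():
        p.fill_(5.0)  # overwrite all parameters so the reset is observable

reset_all_parameters_sketch(net)
print(net[1].weight[:3], net[0].weight[0])
```

After the call, `BatchNorm1d` is back to weight 1 and bias 0 (and its running statistics are cleared), while the `Linear` layers receive fresh random initial values.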