probly.utils.torch

Utility functions for PyTorch models.

Functions

intersection_probability

Compute the intersection probability of a probability interval, following [WCM+24], Section 3.4.
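The sketch below shows one way to compute this. It assumes the intersection-probability formula for probability intervals commonly attributed to Cuzzolin, and it assumes [WCM+24] uses the same definition. Given per-class lower and upper bounds l and u, the result is p = l + alpha * (u - l), where alpha = (1 - sum(l)) / sum(u - l). The function name `intersection_probability_sketch` and the input layout (separate lower and upper tensors, with classes on the last dimension) are hypothetical. This is not probly's actual signature.

```python
import torch


def intersection_probability_sketch(lower: torch.Tensor, upper: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch (not probly's implementation) of the intersection probability.

    Assumes classes lie on the last dimension and the interval is feasible,
    i.e. sum(lower) <= 1 <= sum(upper).
    """
    width = upper - lower
    total_width = width.sum(dim=-1, keepdim=True)
    # alpha distributes the mass missing from the lower bounds
    # in proportion to each class's interval width
    alpha = (1.0 - lower.sum(dim=-1, keepdim=True)) / total_width.clamp_min(1e-12)
    return lower + alpha * width
```

Every class receives the same fraction alpha of its interval width. The result therefore sums to 1 and stays inside each [l, u] whenever the interval is feasible.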

temperature_softmax

Compute the softmax of logits with temperature scaling applied.
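Here is a minimal sketch of temperature scaling. It assumes the usual convention of dividing the logits by a positive scalar temperature T and then applying softmax over the last dimension. The name `temperature_softmax_sketch` and its parameters are hypothetical.

```python
import torch


def temperature_softmax_sketch(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    """Hypothetical sketch: softmax(logits / temperature) over the last dimension."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    # T > 1 flattens the distribution, T < 1 sharpens it, T = 1 is the plain softmax
    return torch.softmax(logits / temperature, dim=-1)
```

Dividing before the softmax leaves the argmax unchanged, so accuracy is unaffected. Only the confidence of the prediction changes.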

torch_collect_outputs

Collect outputs and targets from a model for a given data loader.
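A minimal sketch of the collection loop follows. It assumes the loader yields `(inputs, targets)` pairs, that the model is already on `device`, and that evaluation mode without gradients is wanted. The real function may keep the model in training mode, for example for sampling-based methods. The name `collect_outputs_sketch` and its signature are hypothetical.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader


@torch.no_grad()  # no autograd graph is needed when only collecting predictions
def collect_outputs_sketch(
    model: nn.Module, loader: DataLoader, device: str = "cpu"
) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical sketch: run the model over a loader and concatenate outputs and targets."""
    model.eval()  # assumption: evaluation mode (disables dropout, freezes batch-norm stats)
    outputs, targets = [], []
    for inputs, batch_targets in loader:
        # the model is assumed to live on `device` already; results are moved back to CPU
        outputs.append(model(inputs.to(device)).cpu())
        targets.append(batch_targets.cpu())
    return torch.cat(outputs), torch.cat(targets)
```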

torch_entropy

Compute the Shannon entropy H(p) along the last dimension, treating 0 * log(0) as 0.
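The sketch below computes H(p) = -sum_i p_i log p_i. It assumes the natural logarithm, so results are in nats; the base probly uses is not stated here. The name `entropy_sketch` is hypothetical. `torch.xlogy(p, p)` computes p * log(p) and returns exactly 0 where p is 0, which gives the 0 * log(0) = 0 convention without producing NaNs.

```python
import torch


def entropy_sketch(probs: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: Shannon entropy (natural log) along the last dimension."""
    # torch.xlogy(p, p) == p * log(p), but is defined as 0 wherever p == 0
    return -torch.xlogy(probs, probs).sum(dim=-1)
```

For example, a one-hot distribution has entropy 0. The distribution [0.25, 0.25, 0.5] has entropy 1.5 * ln 2.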

torch_reset_all_parameters

Reset all parameters of a torch module.
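A minimal sketch of the reset follows. It assumes that "reset" means calling `reset_parameters()` on every submodule that defines it, which is the hook built-in layers such as `nn.Linear` and `nn.Conv2d` use for their default initialization. Parameters registered directly on custom modules without such a method are left unchanged. The name `reset_all_parameters_sketch` is hypothetical.

```python
from torch import nn


def reset_all_parameters_sketch(module: nn.Module) -> None:
    """Hypothetical sketch: re-initialize every submodule that defines reset_parameters()."""
    for submodule in module.modules():  # iterates over `module` itself and all descendants
        reset = getattr(submodule, "reset_parameters", None)
        if callable(reset):
            reset()  # restores the layer's default initialization in place
```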