probly.train.evidential.torch.postnet_loss
- probly.train.evidential.torch.postnet_loss(alpha: Tensor, y: Tensor, entropy_weight: float = 1e-05, reduction: str = 'sum') → Tensor
Posterior Networks (PostNet) classification loss.
Implements the expected cross-entropy loss with an entropy regularizer as proposed by [CZugnerGunnemann20].
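As a reading aid, the per-sample objective from the cited paper can be written as below. Here $\alpha_{i,0} = \sum_{c=1}^{C} \alpha_{i,c}$ is the Dirichlet strength of sample $i$, $\psi$ is the digamma function, and $\lambda$ plays the role of entropy_weight. This is our reading of the paper's objective. Whether probly computes exactly this form (for example, with extra numerical-stability terms) is an assumption, not something this page confirms.

```latex
\mathcal{L}_i
  = \underbrace{\psi(\alpha_{i,0}) - \psi(\alpha_{i,y_i})}_{\mathbb{E}_{p \sim \mathrm{Dir}(\alpha_i)}\left[-\log p_{y_i}\right]}
  \;-\; \lambda \, H\big(\mathrm{Dir}(\alpha_i)\big),

H\big(\mathrm{Dir}(\alpha)\big)
  = \sum_{c=1}^{C} \log\Gamma(\alpha_c) - \log\Gamma(\alpha_0)
  + (\alpha_0 - C)\,\psi(\alpha_0)
  - \sum_{c=1}^{C} (\alpha_c - 1)\,\psi(\alpha_c),

\mathcal{L}_{\text{sum}} = \sum_{i=1}^{B} \mathcal{L}_i,
\qquad
\mathcal{L}_{\text{mean}} = \frac{1}{B} \sum_{i=1}^{B} \mathcal{L}_i .
```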
- Reference:
Charpentier et al., “Posterior Networks: Uncertainty Estimation without OOD Samples via Density-Based Pseudo-Counts”, NeurIPS 2020. https://arxiv.org/abs/2006.09239
- Parameters:
alpha – Dirichlet concentration parameters, shape (B, C), where B is the batch size and C the number of classes.
y – Ground-truth class labels, shape (B,).
entropy_weight – Weight of the entropy regularization term. Defaults to 1e-5 as used in the original paper.
reduction – Specifies the reduction to apply to the output. Can be ‘mean’ or ‘sum’. Defaults to ‘sum’ to align with the implementation in the original paper.
- Returns:
Scalar Posterior Networks loss, summed over the batch by default or averaged when reduction='mean'.
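To make the formula concrete, here is a minimal plain-PyTorch sketch of the same objective. It is not probly's source. The function name postnet_loss_sketch is hypothetical. It assumes that y holds integer class indices (torch.long) and that the entropy term is subtracted, as in the paper. The expected cross-entropy uses the closed form ψ(α_0) − ψ(α_y), and the Dirichlet entropy comes from torch.distributions.Dirichlet.

```python
import torch
from torch.distributions import Dirichlet


def postnet_loss_sketch(
    alpha: torch.Tensor,
    y: torch.Tensor,
    entropy_weight: float = 1e-5,
    reduction: str = "sum",
) -> torch.Tensor:
    """Hypothetical re-derivation of the PostNet loss; not probly's implementation.

    alpha: Dirichlet concentrations, shape (B, C), all entries > 0.
    y: integer class indices (torch.long), shape (B,).
    """
    # Dirichlet strength alpha_0 for each sample, shape (B,)
    alpha0 = alpha.sum(dim=-1)
    # Concentration assigned to the true class, shape (B,)
    alpha_y = alpha.gather(1, y.unsqueeze(1)).squeeze(1)
    # Closed form of E_{p ~ Dir(alpha)}[-log p_y]
    expected_ce = torch.digamma(alpha0) - torch.digamma(alpha_y)
    # Entropy of each per-sample Dirichlet, shape (B,)
    entropy = Dirichlet(alpha).entropy()
    # Subtracting the weighted entropy rewards less over-concentrated posteriors
    per_sample = expected_ce - entropy_weight * entropy
    if reduction == "sum":
        return per_sample.sum()
    if reduction == "mean":
        return per_sample.mean()
    raise ValueError(f"unsupported reduction: {reduction!r}")
```

For example, postnet_loss_sketch(torch.tensor([[10.0, 1.0, 1.0], [1.0, 1.0, 1.0]]), torch.tensor([0, 2])) gives a much smaller expected cross-entropy for the first, confidently correct sample (1/10 + 1/11 ≈ 0.19) than for the second, uniform one (1.5). The "sum" default mirrors the documented default above. Switching to "mean" makes the loss scale independent of batch size.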