probly.train.evidential.torch.postnet_loss

probly.train.evidential.torch.postnet_loss(alpha: Tensor, y: Tensor, entropy_weight: float = 1e-05, reduction: str = 'sum') → Tensor

Posterior Networks (PostNet) classification loss.

Implements the expected cross-entropy under the predicted Dirichlet distribution, minus a weighted entropy regularizer, as proposed by Charpentier, Zügner, and Günnemann (2020).
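For reference, the per-sample objective described in the paper can be written as follows, with ψ the digamma function, Γ the gamma function, λ = entropy_weight, and sample i having concentrations α^(i) and label y_i. That probly computes exactly this form is an assumption; implementation details such as numerical clamping may differ. The first two terms of L_i are the closed form of E[−log p_{y_i}] for p ~ Dir(α^(i)). The batch loss is Σ_i L_i for reduction='sum' and (1/B) Σ_i L_i for reduction='mean'.

```latex
\begin{aligned}
\alpha_0^{(i)} &= \sum_{c=1}^{C} \alpha_c^{(i)} \\
\mathcal{L}_i &= \psi\big(\alpha_0^{(i)}\big) - \psi\big(\alpha_{y_i}^{(i)}\big) - \lambda\,\mathbb{H}\big[\mathrm{Dir}(\alpha^{(i)})\big] \\
\mathbb{H}\big[\mathrm{Dir}(\alpha)\big] &= \sum_{c=1}^{C}\log\Gamma(\alpha_c) - \log\Gamma(\alpha_0) + (\alpha_0 - C)\,\psi(\alpha_0) - \sum_{c=1}^{C}(\alpha_c - 1)\,\psi(\alpha_c)
\end{aligned}
```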

Reference:

Charpentier et al., “Posterior Networks: Uncertainty Estimation without OOD Samples via Density-Based Pseudo-Counts”, NeurIPS 2020. https://arxiv.org/abs/2006.09239

Parameters:
  • alpha – Dirichlet concentration parameters, shape (B, C).

  • y – Ground-truth class labels, shape (B,).

  • entropy_weight – Weight of the entropy regularization term. Defaults to 1e-5 as used in the original paper.

  • reduction – Specifies the reduction to apply to the output. Can be ‘mean’ or ‘sum’. Defaults to ‘sum’ to align with the implementation in the original paper.
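The parameters above map onto a short PyTorch computation. The sketch below is a hypothetical re-implementation (`postnet_loss_sketch` is not part of probly). It assumes the objective from Charpentier et al.: the closed-form expected cross-entropy ψ(α₀) − ψ(α_y), minus `entropy_weight` times the Dirichlet entropy, which it takes from `torch.distributions.Dirichlet.entropy`. It illustrates the math and does not reproduce probly's code.

```python
import torch
from torch.distributions import Dirichlet


def postnet_loss_sketch(
    alpha: torch.Tensor,
    y: torch.Tensor,
    entropy_weight: float = 1e-5,
    reduction: str = "sum",
) -> torch.Tensor:
    """Hypothetical PostNet-style loss; an illustration, not probly's implementation."""
    # alpha: (B, C) strictly positive concentrations; y: (B,) integer class labels.
    alpha0 = alpha.sum(dim=-1)  # (B,) total concentration per sample
    alpha_y = alpha.gather(-1, y.long().unsqueeze(-1)).squeeze(-1)  # (B,) true-class concentration
    # Closed-form expected cross-entropy under Dir(alpha): E[-log p_y] = digamma(alpha0) - digamma(alpha_y)
    expected_ce = torch.digamma(alpha0) - torch.digamma(alpha_y)
    # Differential entropy of each sample's Dirichlet, shape (B,)
    entropy = Dirichlet(alpha).entropy()
    # Subtracting the entropy rewards less concentrated Dirichlets (the regularizer)
    per_sample = expected_ce - entropy_weight * entropy
    if reduction == "sum":
        return per_sample.sum()
    if reduction == "mean":
        return per_sample.mean()
    raise ValueError(f"unsupported reduction: {reduction!r}")
```

As a quick check, for alpha = [[1, 1, 1]], y = [0] and entropy_weight=1.0, the expected cross-entropy is ψ(3) − ψ(1) = 1.5 and the entropy of the uniform Dirichlet is −log 2, so the loss is 1.5 + log 2 ≈ 2.193. With the default reduction='sum', the loss grows with batch size, so a learning rate tuned for 'mean' may need rescaling.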

Returns:

Scalar Posterior Networks loss, summed over the batch (reduction='sum', the default) or averaged over it (reduction='mean').