probly.train.evidential.torch.regularization_fn

probly.train.evidential.torch.regularization_fn(alpha: Tensor, y: Tensor) → Tensor

Regularization term for Information Robust Dirichlet Networks.

Penalizes high Dirichlet concentration values for incorrect classes to encourage confident but well-calibrated predictions.

Reference:

Tsiligkaridis, “Information Robust Dirichlet Networks for Predictive Uncertainty Estimation”, 2019. https://arxiv.org/abs/1910.04819

Parameters:
  • alpha – Dirichlet concentration parameters, shape (B, K), must be > 0.

  • y – One-hot encoded class labels, shape (B, K).

Returns:

Scalar regularization loss summed over classes and batch.

Raises:

ValueError – If the shapes of alpha and y do not match.
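
Example:

The following is a minimal sketch of a function that satisfies the contract documented above. It is not probly's implementation. The name ird_regularizer_sketch is hypothetical, and the exact penalty form, 0.5 * (alpha - 1)^2 * (trigamma(alpha) - trigamma(alpha_0)) evaluated on incorrect classes only, is an assumption based on the referenced paper that has not been checked against the library source. It illustrates the shape check, the masking of the true class, and the reduction to a scalar.

```python
import torch
import torch.nn.functional as F
from torch import Tensor


def ird_regularizer_sketch(alpha: Tensor, y: Tensor) -> Tensor:
    """Hypothetical stand-in for illustration; not probly's code."""
    if alpha.shape != y.shape:
        raise ValueError(
            f"alpha and y must have the same shape, "
            f"got {tuple(alpha.shape)} and {tuple(y.shape)}"
        )
    # Set the true-class concentration to 1 so that only the
    # incorrect classes contribute to the penalty.
    alpha_tilde = y + (1.0 - y) * alpha
    alpha_tilde_0 = alpha_tilde.sum(dim=1, keepdim=True)
    # Assumed penalty form (see the note above): the trigamma function
    # is torch.polygamma with order 1.
    penalty = 0.5 * (alpha_tilde - 1.0) ** 2 * (
        torch.polygamma(1, alpha_tilde) - torch.polygamma(1, alpha_tilde_0)
    )
    # Sum over classes and batch to obtain a scalar.
    return penalty.sum()


# Batch of B=2 samples with K=3 classes.
alpha = torch.tensor([[5.0, 1.0, 1.0], [1.0, 4.0, 2.0]])
y = F.one_hot(torch.tensor([0, 1]), num_classes=3).float()
loss = ird_regularizer_sketch(alpha, y)
```

In this sketch the first sample contributes nothing, because its incorrect classes already have concentration 1. The second sample is penalized for the concentration of 2 placed on an incorrect class.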