probly.train.evidential.torch.regularization_fn
- probly.train.evidential.torch.regularization_fn(alpha: Tensor, y: Tensor) → Tensor
Regularization term for Information Robust Dirichlet Networks.
Penalizes high Dirichlet concentration values for incorrect classes to encourage confident but well-calibrated predictions.
- Reference:
Tsiligkaridis, “Information Robust Dirichlet Networks for Predictive Uncertainty Estimation”, 2019. https://arxiv.org/abs/1910.04819
- Parameters:
alpha – Dirichlet concentration parameters, shape (B, K), must be > 0.
y – One-hot encoded class labels, shape (B, K).
- Returns:
Scalar regularization loss summed over classes and batch.
- Raises:
ValueError – If the shapes of alpha and y do not match.
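- Example:
A minimal sketch of the documented contract, not the library's implementation. The function name regularization_sketch is hypothetical, and the penalty used here (squared excess of alpha above 1 on the incorrect classes) is an assumed stand-in rather than the formula from the referenced paper. It only illustrates the shape check, the masking of the true class via the one-hot labels, and the scalar return value.

```python
import torch
from torch import Tensor


def regularization_sketch(alpha: Tensor, y: Tensor) -> Tensor:
    """Hypothetical stand-in that mirrors the documented inputs and outputs."""
    if alpha.shape != y.shape:
        raise ValueError(
            f"alpha and y must have the same shape, got "
            f"{tuple(alpha.shape)} and {tuple(y.shape)}"
        )
    # One-hot labels mark the true class; (1 - y) selects the incorrect classes.
    incorrect = 1.0 - y
    # Assumed penalty (NOT the paper's formula): squared excess concentration above 1.
    penalty = incorrect * (alpha - 1.0) ** 2
    # Sum over classes and batch to obtain a scalar (0-dimensional) tensor.
    return penalty.sum()


# Batch of B=2 samples with K=3 classes.
alpha = torch.tensor([[5.0, 1.0, 1.0], [2.0, 3.0, 1.0]])
y = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
loss = regularization_sketch(alpha, y)
```

In this sketch the first sample contributes nothing, because its incorrect classes sit at alpha = 1, while the second sample is penalized only for the concentration of 2 placed on an incorrect class. The real function may weight these terms differently.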