probly.train.evidential.torch.pn_loss
- probly.train.evidential.torch.pn_loss(model: Module, x_in: Tensor, y_in: Tensor, x_ood: Tensor) → Tensor
Paired ID/OOD training loss for Dirichlet Prior Networks.
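Combines two KL-divergence terms, one pulling predictions on in-distribution inputs toward sharp Dirichlet targets and one pulling predictions on out-of-distribution inputs toward flat Dirichlet targets, with an additional cross-entropy term for classification stability.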
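As a rough guide, an objective of this kind can be written in the general form below. The weights \lambda and \gamma, the names of the target concentrations, and the direction of the KL terms (target first, as in the cited paper) are notational assumptions; this page does not state the exact values or conventions that pn_loss uses.

```latex
\mathcal{L}(\theta) =
  \mathbb{E}_{(x, y) \sim \mathrm{ID}} \Big[
    \mathrm{KL}\big( \mathrm{Dir}(\mu \mid \alpha^{\mathrm{sharp}}_{y})
      \,\|\, \mathrm{Dir}(\mu \mid \alpha_\theta(x)) \big)
    + \lambda \, \mathrm{CE}\big( y, \hat{p}_\theta(x) \big)
  \Big]
  + \gamma \, \mathbb{E}_{x \sim \mathrm{OOD}} \Big[
    \mathrm{KL}\big( \mathrm{Dir}(\mu \mid \alpha^{\mathrm{flat}})
      \,\|\, \mathrm{Dir}(\mu \mid \alpha_\theta(x)) \big)
  \Big],
\qquad
\hat{p}_\theta(x) = \frac{\alpha_\theta(x)}{\sum_{c} \alpha_{\theta, c}(x)}
```

Here \alpha_\theta(x) denotes the concentration parameters produced by the model, \alpha^{\mathrm{sharp}}_{y} concentrates mass on the labelled class, and \alpha^{\mathrm{flat}} is a uniform target such as all ones.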
- Reference:
Malinin and Gales, “Predictive Uncertainty Estimation via Prior Networks”, NeurIPS 2018. https://arxiv.org/abs/1802.10501
- Parameters:
model – Network mapping inputs to Dirichlet concentration parameters.
x_in – In-distribution inputs, shape (B, …).
y_in – In-distribution class labels, shape (B,).
x_ood – Out-of-distribution inputs, shape (B_ood, …).
- Returns:
Scalar paired ID+OOD Prior Networks loss.
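The following is a minimal sketch of how a paired ID/OOD loss of this shape could be assembled in plain PyTorch, to make the inputs and the scalar output concrete. It is not the probly implementation: the function name paired_pn_loss_sketch, the ToyPriorNet model, the target_precision and ce_weight arguments, the choice of target concentrations, and the KL direction are all assumptions made for illustration.

```python
import torch
import torch.nn.functional as F
from torch import nn
from torch.distributions import Dirichlet, kl_divergence


class ToyPriorNet(nn.Module):
    """Hypothetical model that maps inputs to positive Dirichlet concentrations."""

    def __init__(self, in_features: int, num_classes: int) -> None:
        super().__init__()
        self.linear = nn.Linear(in_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # softplus(...) + 1 keeps every concentration strictly above 1.
        return F.softplus(self.linear(x)) + 1.0


def paired_pn_loss_sketch(
    model: nn.Module,
    x_in: torch.Tensor,
    y_in: torch.Tensor,
    x_ood: torch.Tensor,
    target_precision: float = 100.0,  # assumed extra mass on the true class
    ce_weight: float = 1.0,           # assumed weight of the cross-entropy term
) -> torch.Tensor:
    """Illustrative paired ID/OOD loss; an assumption, not probly's code."""
    alpha_in = model(x_in)    # (B, C)
    alpha_ood = model(x_ood)  # (B_ood, C)
    num_classes = alpha_in.shape[-1]

    # Sharp ID target: ones everywhere plus extra concentration on the label.
    one_hot = F.one_hot(y_in, num_classes).to(alpha_in.dtype)
    target_in = torch.ones_like(alpha_in) + target_precision * one_hot
    # Flat OOD target: the uniform Dirichlet (all concentrations equal to 1).
    target_ood = torch.ones_like(alpha_ood)

    # KL from each target Dirichlet to the predicted one, averaged per batch.
    kl_in = kl_divergence(Dirichlet(target_in), Dirichlet(alpha_in)).mean()
    kl_ood = kl_divergence(Dirichlet(target_ood), Dirichlet(alpha_ood)).mean()

    # Cross-entropy on the expected class probabilities of the ID predictions.
    probs = alpha_in / alpha_in.sum(dim=-1, keepdim=True)
    ce = F.nll_loss(torch.log(probs), y_in)

    return kl_in + kl_ood + ce_weight * ce


torch.manual_seed(0)
model = ToyPriorNet(in_features=4, num_classes=3)
x_in = torch.randn(8, 4)            # in-distribution inputs, shape (B, 4)
y_in = torch.randint(0, 3, (8,))    # class labels, shape (B,)
x_ood = torch.randn(5, 4)           # out-of-distribution inputs, shape (B_ood, 4)

loss = paired_pn_loss_sketch(model, x_in, y_in, x_ood)
loss.backward()
```

The ID and OOD batches may have different sizes, since each KL term is averaged over its own batch before the terms are summed into a single scalar.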