probly.train.evidential.torch.pn_loss

probly.train.evidential.torch.pn_loss(model: Module, x_in: Tensor, y_in: Tensor, x_ood: Tensor) → Tensor

Paired ID/OOD training loss for Dirichlet Prior Networks.

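Combines two KL divergence terms, one pulling in-distribution predictions toward sharp Dirichlet targets and one pulling out-of-distribution predictions toward flat Dirichlet targets, with an additional cross-entropy term on the in-distribution labels for classification stability.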
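The combination described above can be sketched roughly as follows. This is a minimal illustration, not the library's implementation: the name pn_loss_sketch, the sharp target concentration, the cross-entropy weight, and the direction of the KL terms (target relative to prediction) are all assumptions made for the example.

```python
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.distributions import Dirichlet, kl_divergence


def pn_loss_sketch(
    model: nn.Module,
    x_in: Tensor,
    y_in: Tensor,
    x_ood: Tensor,
    sharp_concentration: float = 100.0,  # assumed value, not taken from the docs
    ce_weight: float = 1.0,  # assumed value, not taken from the docs
) -> Tensor:
    """Illustrative paired ID/OOD loss; hypothetical, not probly's code."""
    alpha_in = model(x_in)  # (B, C) positive concentrations
    alpha_ood = model(x_ood)  # (B_ood, C) positive concentrations
    num_classes = alpha_in.shape[-1]

    # Sharp target: concentration 1 everywhere, large on the labelled class.
    one_hot = F.one_hot(y_in, num_classes).to(alpha_in.dtype)
    target_in = 1.0 + (sharp_concentration - 1.0) * one_hot
    # Flat target: uniform Dirichlet with all concentrations equal to 1.
    target_ood = torch.ones_like(alpha_ood)

    # KL(target || prediction), averaged over each batch (assumed direction).
    kl_in = kl_divergence(Dirichlet(target_in), Dirichlet(alpha_in)).mean()
    kl_ood = kl_divergence(Dirichlet(target_ood), Dirichlet(alpha_ood)).mean()

    # Cross-entropy on the expected class probabilities alpha / sum(alpha).
    log_probs = torch.log(alpha_in / alpha_in.sum(dim=-1, keepdim=True))
    ce = F.nll_loss(log_probs, y_in)

    return kl_in + kl_ood + ce_weight * ce


class ToyPriorNet(nn.Module):
    """Hypothetical toy network whose outputs are concentrations above 1."""

    def __init__(self, in_dim: int, num_classes: int) -> None:
        super().__init__()
        self.linear = nn.Linear(in_dim, num_classes)

    def forward(self, x: Tensor) -> Tensor:
        return F.softplus(self.linear(x)) + 1.0


torch.manual_seed(0)
model = ToyPriorNet(in_dim=4, num_classes=3)
x_in = torch.randn(8, 4)
y_in = torch.randint(0, 3, (8,))
x_ood = torch.randn(5, 4) * 10.0  # stand-in for OOD data

loss = pn_loss_sketch(model, x_in, y_in, x_ood)
loss.backward()  # gradients flow to the model parameters
```

Each KL term is averaged over its own batch, so the in-distribution and out-of-distribution batches may differ in size.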

Reference:

Malinin and Gales, “Predictive Uncertainty Estimation via Prior Networks”, NeurIPS 2018. https://arxiv.org/abs/1802.10501

Parameters:
  • model – Network mapping inputs to Dirichlet concentration parameters.

  • x_in – In-distribution inputs, shape (B, …).

  • y_in – In-distribution class labels, shape (B,).

  • x_ood – Out-of-distribution inputs, shape (B_ood, …).

Returns:

Scalar paired ID+OOD Prior Networks loss.
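As a rough illustration of the shapes listed under Parameters, the snippet below prepares inputs for a call. The toy network, batch sizes, and feature dimension are hypothetical, and the import path in the comment is inferred from the qualified name above rather than confirmed.

```python
import torch
import torch.nn.functional as F
from torch import Tensor, nn


class ToyPriorNet(nn.Module):
    """Hypothetical network mapping inputs to positive concentrations."""

    def __init__(self, in_dim: int, num_classes: int) -> None:
        super().__init__()
        self.linear = nn.Linear(in_dim, num_classes)

    def forward(self, x: Tensor) -> Tensor:
        # softplus(.) + 1 keeps every concentration strictly above 1.
        return F.softplus(self.linear(x)) + 1.0


torch.manual_seed(0)
num_classes, in_dim = 3, 4
model = ToyPriorNet(in_dim, num_classes)

x_in = torch.randn(8, in_dim)  # (B, ...) with B = 8
y_in = torch.randint(0, num_classes, (8,))  # (B,) integer class labels
x_ood = torch.randn(5, in_dim)  # (B_ood, ...) with B_ood = 5

alpha_in = model(x_in)  # (B, C)
alpha_ood = model(x_ood)  # (B_ood, C)

# With probly installed, the call would presumably look like this
# (import path assumed from the qualified name):
# from probly.train.evidential.torch import pn_loss
# loss = pn_loss(model, x_in, y_in, x_ood)
```

In this sketch B and B_ood are deliberately different, which matches the separate shape entries for x_in and x_ood.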