probly.train.evidential.torch.rpn_loss

probly.train.evidential.torch.rpn_loss(model: Module, x_id: Tensor, y_id: Tensor, x_ood: Tensor, lam_der: float = 0.01, lam_rpn: float = 50.0) → Tensor

Paired in-distribution and out-of-distribution loss for Regression Prior Networks.

Computes the Regression Prior Network (RPN) training objective using paired in-distribution (ID) and out-of-distribution (OOD) mini-batches. The loss combines a supervised Deep Evidential Regression (DER) term on ID data with a KL regularization term that pushes OOD predictions back toward the Normal-Gamma prior.
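The paired objective described above can be sketched as follows. This is a minimal illustration under stated assumptions, not probly's implementation. The DER term uses the Normal-Inverse-Gamma negative log-likelihood and evidence regularizer from Amini et al., "Deep Evidential Regression", which is equivalent to the Normal-Gamma parameterization over the precision. The OOD term is assumed to be the reverse KL, KL(model || prior), averaged over the OOD batch. The prior hyperparameters `prior=(0.0, 1.0, 1.0, 1.0)` and the helper names `nig_nll`, `evidence_reg`, `normal_gamma_kl` and `paired_rpn_loss_sketch` are hypothetical.

```python
import math

import torch
from torch import Tensor


def nig_nll(y: Tensor, mu: Tensor, kappa: Tensor, alpha: Tensor, beta: Tensor) -> Tensor:
    # Per-sample negative log-likelihood of y under the Student-t marginal of NIG(mu, kappa, alpha, beta).
    omega = 2.0 * beta * (1.0 + kappa)
    return (
        0.5 * torch.log(math.pi / kappa)
        - alpha * torch.log(omega)
        + (alpha + 0.5) * torch.log(kappa * (y - mu) ** 2 + omega)
        + torch.lgamma(alpha)
        - torch.lgamma(alpha + 0.5)
    )


def evidence_reg(y: Tensor, mu: Tensor, kappa: Tensor, alpha: Tensor) -> Tensor:
    # Penalizes total evidence (2 * kappa + alpha) in proportion to the absolute error.
    return (y - mu).abs() * (2.0 * kappa + alpha)


def normal_gamma_kl(mu_q, kappa_q, alpha_q, beta_q, mu_p, kappa_p, alpha_p, beta_p) -> Tensor:
    # KL(q || p) for Normal-Gamma distributions:
    #   lam ~ Gamma(alpha, rate=beta),  mu | lam ~ N(m, 1 / (kappa * lam)).
    # Gamma part, shape-rate parameterization.
    kl_gamma = (
        (alpha_q - alpha_p) * torch.digamma(alpha_q)
        - torch.lgamma(alpha_q)
        + torch.lgamma(alpha_p)
        + alpha_p * (torch.log(beta_q) - torch.log(beta_p))
        + alpha_q * (beta_p - beta_q) / beta_q
    )
    # Gaussian KL averaged over lam ~ Gamma(alpha_q, beta_q), using E[lam] = alpha_q / beta_q.
    kl_normal = 0.5 * (
        torch.log(kappa_q / kappa_p)
        + kappa_p / kappa_q
        - 1.0
        + kappa_p * (alpha_q / beta_q) * (mu_q - mu_p) ** 2
    )
    return kl_gamma + kl_normal


def paired_rpn_loss_sketch(model, x_id, y_id, x_ood, lam_der=0.01, lam_rpn=50.0,
                           prior=(0.0, 1.0, 1.0, 1.0)) -> Tensor:
    # ID half: DER loss (NLL + lam_der * evidence regularizer), averaged over the batch.
    mu, kappa, alpha, beta = model(x_id)
    y = y_id.reshape(mu.shape)  # assumes y_id holds one target per prediction
    der = (nig_nll(y, mu, kappa, alpha, beta) + lam_der * evidence_reg(y, mu, kappa, alpha)).mean()

    # OOD half: reverse KL from each prediction to the (assumed, hypothetical) prior, averaged.
    mu_o, kappa_o, alpha_o, beta_o = model(x_ood)
    prior_t = [torch.full_like(mu_o, v) for v in prior]
    kl = normal_gamma_kl(mu_o, kappa_o, alpha_o, beta_o, *prior_t).mean()

    return der + lam_rpn * kl
```

Splitting the Normal-Gamma KL into a Gamma part and an expected Gaussian part (chain rule for KL) keeps it in closed form; `torch.distributions` has no Normal-Gamma class, so the formula is written out by hand here.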

Reference:

Malinin et al., “Regression Prior Networks”, NeurIPS 2020. https://arxiv.org/abs/2006.11590

Parameters:
  • model – Regression model returning a tuple (mu, kappa, alpha, beta) of Normal-Gamma parameters for each input.

  • x_id – In-distribution inputs, shape (B_id, …).

  • y_id – In-distribution regression targets, shape (B_id,) or compatible.

  • x_ood – Out-of-distribution inputs, shape (B_ood, …).

  • lam_der – Weight of the DER evidence regularization term.

  • lam_rpn – Weight of the RPN prior-matching KL term.
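The `model` contract in the parameter list can be illustrated with a small module. `ToyEvidentialRegressor`, its layer sizes, the (B,) output shape and the softplus constraints are all hypothetical choices for illustration, common in evidential regression but not specified by this API. The constraints keep kappa > 0 and beta > 0, and alpha > 1 so that the expected variance beta / (alpha - 1) is finite.

```python
import torch
import torch.nn.functional as F
from torch import Tensor, nn


class ToyEvidentialRegressor(nn.Module):
    # Hypothetical model: maps (B, in_features) inputs to four (B,) tensors (mu, kappa, alpha, beta).

    def __init__(self, in_features: int, hidden: int = 32) -> None:
        super().__init__()
        self.body = nn.Sequential(nn.Linear(in_features, hidden), nn.ReLU())
        self.head = nn.Linear(hidden, 4)  # one raw output per Normal-Gamma parameter

    def forward(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        mu, k_raw, a_raw, b_raw = self.head(self.body(x)).unbind(dim=-1)
        kappa = F.softplus(k_raw) + 1e-6  # kappa > 0
        alpha = F.softplus(a_raw) + 1.0   # alpha > 1 keeps beta / (alpha - 1) finite
        beta = F.softplus(b_raw) + 1e-6   # beta > 0
        return mu, kappa, alpha, beta


model = ToyEvidentialRegressor(in_features=8)
mu, kappa, alpha, beta = model(torch.randn(16, 8))
```

Any module with this output signature could in principle be passed as `model`; whether the function expects (B,) or (B, 1) parameter tensors is an assumption here.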

Returns:

Scalar paired ID+OOD Regression Prior Network loss.