probly.train.evidential.torch.rpn_loss
- probly.train.evidential.torch.rpn_loss(model: Module, x_id: Tensor, y_id: Tensor, x_ood: Tensor, lam_der: float = 0.01, lam_rpn: float = 50.0) → Tensor
Paired in-distribution and out-of-distribution loss for Regression Prior Networks.
Computes the Regression Prior Network (RPN) training objective using paired in-distribution (ID) and out-of-distribution (OOD) mini-batches. The loss combines a supervised Deep Evidential Regression (DER) term on ID data with a KL regularization term that pushes the Normal-Gamma parameters predicted on OOD inputs toward the Normal-Gamma prior.
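The composite objective described above can be sketched as follows. This is an illustrative sketch, not probly's implementation. The DER negative log-likelihood and evidence regularizer follow the Deep Evidential Regression formulation (Amini et al., 2020). The OOD term is assumed to be KL(predicted Normal-Gamma || fixed prior). The prior values, the batch-mean reduction, the unit weight on the NLL, and the helper names `der_nll`, `der_reg`, `kl_normal_gamma` and `rpn_loss_sketch` are all hypothetical.

```python
import math

import torch


def der_nll(y, mu, kappa, alpha, beta):
    # Negative log of the Student-t marginal of a Normal-Gamma (DER NLL).
    omega = 2.0 * beta * (1.0 + kappa)
    return (
        0.5 * torch.log(math.pi / kappa)
        - alpha * torch.log(omega)
        + (alpha + 0.5) * torch.log(kappa * (y - mu) ** 2 + omega)
        + torch.lgamma(alpha)
        - torch.lgamma(alpha + 0.5)
    )


def der_reg(y, mu, kappa, alpha):
    # Evidence regularizer: penalises total evidence (2*kappa + alpha) on large errors.
    return torch.abs(y - mu) * (2.0 * kappa + alpha)


def kl_normal_gamma(mu1, k1, a1, b1, mu0, k0, a0, b0):
    # KL(NG(mu1, k1, a1, b1) || NG(mu0, k0, a0, b0)); Gamma uses shape/rate form.
    kl_gamma = (
        (a1 - a0) * torch.digamma(a1)
        - torch.lgamma(a1)
        + torch.lgamma(a0)
        + a0 * (torch.log(b1) - torch.log(b0))
        + a1 * (b0 - b1) / b1
    )
    # Expected KL between the conditional Gaussians, using E[tau] = a1 / b1.
    kl_gauss = 0.5 * (k0 / k1 + k0 * (a1 / b1) * (mu1 - mu0) ** 2 - 1.0 + torch.log(k1 / k0))
    return kl_gamma + kl_gauss


def rpn_loss_sketch(model, x_id, y_id, x_ood, lam_der=0.01, lam_rpn=50.0,
                    prior=(0.0, 1.0, 1.0, 1.0)):
    # ID term: DER NLL plus lam_der-weighted evidence regularizer.
    mu, kappa, alpha, beta = model(x_id)
    y = y_id.reshape(mu.shape)
    id_term = der_nll(y, mu, kappa, alpha, beta) + lam_der * der_reg(y, mu, kappa, alpha)

    # OOD term: KL from the predicted Normal-Gamma to a fixed (hypothetical) prior.
    mu_o, kappa_o, alpha_o, beta_o = model(x_ood)
    prior_params = [torch.full_like(mu_o, p) for p in prior]
    ood_term = kl_normal_gamma(mu_o, kappa_o, alpha_o, beta_o, *prior_params)

    # Batch means (assumed reduction), combined into one scalar.
    return id_term.mean() + lam_rpn * ood_term.mean()
```

Because the KL is zero when an OOD prediction equals the prior, the `lam_rpn` term only contributes when the model is confident on OOD inputs, while `lam_der` only scales the ID regularizer.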
- Reference:
Malinin et al., “Regression Prior Networks”, NeurIPS 2020. https://arxiv.org/abs/2006.11590
- Parameters:
model – Regression model returning (mu, kappa, alpha, beta) for each input.
x_id – In-distribution inputs, shape (B_id, …).
y_id – In-distribution regression targets, shape (B_id,) or compatible.
x_ood – Out-of-distribution inputs, shape (B_ood, …).
lam_der – Weight of the DER evidence regularization term.
lam_rpn – Weight of the RPN prior-matching KL term.
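The `model` argument is specified only by its outputs. A minimal sketch of a module satisfying that contract is shown below, assuming one value of each parameter per input and the usual Deep Evidential Regression constraints (kappa > 0, alpha > 1, beta > 0) enforced with a softplus. The class name `EvidentialRegressor`, the layer sizes and the `eps` offset are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class EvidentialRegressor(nn.Module):
    """Hypothetical model returning (mu, kappa, alpha, beta), one value per input."""

    def __init__(self, in_features: int, hidden: int = 32, eps: float = 1e-6):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(in_features, hidden), nn.ReLU())
        self.head = nn.Linear(hidden, 4)  # one raw output per Normal-Gamma parameter
        self.eps = eps

    def forward(self, x):
        mu, k_raw, a_raw, b_raw = self.head(self.body(x)).unbind(dim=-1)
        # Softplus keeps the evidential parameters positive; eps guards against underflow to 0.
        kappa = F.softplus(k_raw) + self.eps
        alpha = 1.0 + F.softplus(a_raw) + self.eps  # alpha > 1 (assumed DER convention)
        beta = F.softplus(b_raw) + self.eps
        return mu, kappa, alpha, beta
```

Such a module could then be passed as `model` to `rpn_loss(model, x_id, y_id, x_ood)`; whether probly requires exactly these parameter constraints is an assumption.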
- Returns:
Scalar paired ID+OOD Regression Prior Network loss.