probly.train.evidential.torch.rpn_distillation_loss
- probly.train.evidential.torch.rpn_distillation_loss(rpn_params: tuple[Tensor, Tensor, Tensor, Tensor], mus: list[Tensor], variances: list[Tensor]) → Tensor
Compute the distillation loss for Regression Prior Networks (RPN).
This loss measures how well the RPN's predicted Normal-Wishart distribution matches the empirical distribution of the ensemble members' predictions (mu_k, var_k).
- Parameters:
rpn_params – The RPN output parameters (m, l_precision, kappa, nu) of the predicted Normal-Wishart distribution.
mus – Predicted means, one tensor per ensemble member.
variances – Predicted variances, one tensor per ensemble member.
- Returns:
The scalar distillation loss.
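To make the computation concrete, the sketch below reimplements the idea for the scalar (1-D) case, where the Normal-Wishart reduces to a Normal-Gamma: each ensemble member contributes the negative log-density of its (mean, precision) pair under the RPN's predicted distribution. This is an illustrative assumption about the loss, not the library's actual implementation; the function name `rpn_distillation_loss_sketch` and the exact averaging scheme are hypothetical.

```python
import torch
from torch.distributions import Gamma, Normal


def rpn_distillation_loss_sketch(rpn_params, mus, variances):
    # Hypothetical 1-D sketch: NW(mu, lam | m, L, kappa, nu)
    # = N(mu | m, (kappa * lam)^-1) * Wishart_1(lam | L, nu),
    # with lam = 1 / var and Wishart_1(L, nu) == Gamma(nu/2, rate=1/(2L)).
    m, l_precision, kappa, nu = rpn_params
    losses = []
    for mu_k, var_k in zip(mus, variances):
        lam = 1.0 / var_k  # ensemble member's precision
        log_p_lam = Gamma(nu / 2, 1.0 / (2.0 * l_precision)).log_prob(lam)
        log_p_mu = Normal(m, (kappa * lam).rsqrt()).log_prob(mu_k)
        # Negative log-likelihood, averaged over the batch dimension
        losses.append(-(log_p_lam + log_p_mu).mean())
    # Average over ensemble members to get a scalar loss
    return torch.stack(losses).mean()
```

Minimizing this quantity pulls the Normal-Wishart's mean parameters (m, l_precision) toward the ensemble's means and precisions, while kappa and nu control how concentrated the distribution is around them.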