probly.train.evidential.torch.rpn_distillation_loss

probly.train.evidential.torch.rpn_distillation_loss(rpn_params: tuple[Tensor, Tensor, Tensor, Tensor], mus: list[Tensor], variances: list[Tensor]) → Tensor

Compute the distillation loss for Regression Prior Networks (RPN).

This loss measures how well the RPN’s Normal-Wishart distribution matches the empirical distribution of the ensemble's predictions, given as per-member means and variances (mu_k, var_k).

Parameters:
  • rpn_params – The RPN output parameters (m, l_precision, kappa, nu).

  • mus – Predicted means of the ensemble (mu_k), as a list of tensors.

  • variances – Predicted variances of the ensemble (var_k), as a list of tensors.

Returns:

Scalar loss value.
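
Example:

The following is a minimal sketch of the idea behind this loss, not probly's implementation. It assumes a one-dimensional regression target, where the Normal-Wishart distribution reduces to a Normal-Gamma, and scores each ensemble member's (mu_k, 1 / var_k) pair under that distribution. The function name `normal_gamma_distillation_nll`, the tensor shapes, and the reading of `l_precision` as the Wishart scale are assumptions made for illustration.

```python
import torch
from torch import Tensor
from torch.distributions import Gamma, Normal


def normal_gamma_distillation_nll(
    rpn_params: tuple[Tensor, Tensor, Tensor, Tensor],
    mus: list[Tensor],
    variances: list[Tensor],
) -> Tensor:
    """Hypothetical 1-D stand-in for an RPN distillation loss."""
    # Assumption: each parameter has shape (batch,), and l_precision is
    # the (scalar) Wishart scale.
    m, l_precision, kappa, nu = rpn_params

    # Assumption: one tensor of shape (batch,) per ensemble member.
    mu_k = torch.stack(mus)               # (K, batch)
    lam_k = 1.0 / torch.stack(variances)  # ensemble precisions, (K, batch)

    # In one dimension, Wishart(lam | L, nu) is Gamma(shape=nu/2, rate=1/(2L)).
    log_p_lam = Gamma(nu / 2.0, 1.0 / (2.0 * l_precision)).log_prob(lam_k)

    # Conditional Normal over the mean: N(mu | m, 1 / (kappa * lam)).
    log_p_mu = Normal(m, (kappa * lam_k).rsqrt()).log_prob(mu_k)

    # Negative log-likelihood averaged over members and batch entries.
    return -(log_p_lam + log_p_mu).mean()


torch.manual_seed(0)
batch, n_members = 8, 5
rpn_params = (
    torch.zeros(batch),         # m
    torch.ones(batch),          # l_precision
    torch.ones(batch),          # kappa
    torch.full((batch,), 3.0),  # nu
)
mus = [torch.randn(batch) for _ in range(n_members)]
variances = [torch.rand(batch) + 0.5 for _ in range(n_members)]
loss = normal_gamma_distillation_nll(rpn_params, mus, variances)
```

The result is a zero-dimensional tensor, matching the scalar return described above. The actual function may differ in tensor shapes, in how the parameters are interpreted, and in the exact form of the objective.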