probly.train.evidential.torch.kl_dirichlet
- probly.train.evidential.torch.kl_dirichlet(prior_alpha: Tensor, posterior_alpha: Tensor) → Tensor
Compute the Kullback-Leibler divergence KL(Dir(alpha_p) || Dir(alpha_q)) for each batch element.
Used by Posterior Networks, Dirichlet Prior Networks, and PN-style in-distribution / out-of-distribution losses to compare Dirichlet distributions.
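For reference, the standard closed form of the KL divergence between two Dirichlet distributions over C classes is given below, using the alpha_p / alpha_q notation above. It is background only: this page does not state whether the function evaluates exactly this expression internally. Here alpha_{p,0} and alpha_{q,0} are the total concentrations (sums over classes), and psi is the digamma function.

```latex
\mathrm{KL}\big(\mathrm{Dir}(\alpha_p)\,\|\,\mathrm{Dir}(\alpha_q)\big)
  = \ln\Gamma(\alpha_{p,0}) - \sum_{c=1}^{C}\ln\Gamma(\alpha_{p,c})
  - \ln\Gamma(\alpha_{q,0}) + \sum_{c=1}^{C}\ln\Gamma(\alpha_{q,c})
  + \sum_{c=1}^{C}\left(\alpha_{p,c}-\alpha_{q,c}\right)\left(\psi(\alpha_{p,c})-\psi(\alpha_{p,0})\right),
\qquad \alpha_{p,0}=\sum_{c=1}^{C}\alpha_{p,c},\quad \alpha_{q,0}=\sum_{c=1}^{C}\alpha_{q,c}
```

The last sum comes from the expectation E_{Dir(alpha_p)}[ln x_c] = psi(alpha_{p,c}) - psi(alpha_{p,0}).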
- Parameters:
prior_alpha – Prior Dirichlet concentration parameters, shape (B, C), where B is the batch size and C is the number of classes.
posterior_alpha – Posterior Dirichlet concentration parameters, shape (B, C).
- Returns:
KL divergence for each batch element, shape (B,).
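A minimal PyTorch sketch of the batched closed-form computation, for illustration only. `kl_dirichlet_sketch` is a hypothetical helper, not probly's implementation, and it assumes alpha_p corresponds to the first argument (`prior_alpha`) and alpha_q to the second (`posterior_alpha`). Because the KL divergence is asymmetric, confirm that direction against the library source before relying on it. The example checks the result against `torch.distributions.kl_divergence`, which PyTorch provides for a pair of `Dirichlet` distributions.

```python
import torch
from torch import Tensor
from torch.distributions import Dirichlet, kl_divergence


def kl_dirichlet_sketch(alpha_p: Tensor, alpha_q: Tensor) -> Tensor:
    """Hypothetical reference for KL(Dir(alpha_p) || Dir(alpha_q)) per batch row.

    Assumes both inputs have shape (B, C) with strictly positive entries;
    returns a tensor of shape (B,).
    """
    # Total concentration of each row, shape (B,)
    sum_p = alpha_p.sum(dim=-1)
    sum_q = alpha_q.sum(dim=-1)
    # Difference of the log-normalisers of the two Dirichlet densities
    log_norm = (
        torch.lgamma(sum_p)
        - torch.lgamma(alpha_p).sum(dim=-1)
        - torch.lgamma(sum_q)
        + torch.lgamma(alpha_q).sum(dim=-1)
    )
    # E_p[log x_c] = digamma(alpha_p[c]) - digamma(sum_p), shape (B, C)
    expected_log = torch.digamma(alpha_p) - torch.digamma(sum_p).unsqueeze(-1)
    cross = ((alpha_p - alpha_q) * expected_log).sum(dim=-1)
    return log_norm + cross


# Usage: a batch of B=4 items over C=3 classes
torch.manual_seed(0)
prior_alpha = torch.ones(4, 3)                   # flat Dirichlet prior
posterior_alpha = 1.0 + 5.0 * torch.rand(4, 3)   # strictly positive concentrations

kl = kl_dirichlet_sketch(prior_alpha, posterior_alpha)
print(kl.shape)  # torch.Size([4])

# Cross-check against PyTorch's registered Dirichlet-Dirichlet KL rule
ref = kl_divergence(Dirichlet(prior_alpha), Dirichlet(posterior_alpha))
print(torch.allclose(kl, ref, atol=1e-4))
```

Working on whole (B, C) tensors and reducing over the last dimension keeps the computation vectorised, so the (B,) output lines up with the per-element return value described above.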