probly.train.evidential.torch.kl_dirichlet

probly.train.evidential.torch.kl_dirichlet(prior_alpha: Tensor, posterior_alpha: Tensor) → Tensor

Compute KL(Dir(alpha_p) || Dir(alpha_q)) for each batch item.

Used by Posterior Networks, Dirichlet Prior Networks, and PN-style in-distribution / out-of-distribution losses to compare Dirichlet distributions.

Parameters:
  • prior_alpha – Prior Dirichlet concentration parameters, shape (B, C).

  • posterior_alpha – Posterior Dirichlet concentration parameters, shape (B, C).

Returns:

KL divergence for each batch element, shape (B,).
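For orientation, the KL divergence between two Dirichlet distributions has a well-known closed form, KL(Dir(a) || Dir(b)) = ln Γ(a_0) − Σ ln Γ(a_i) − ln Γ(b_0) + Σ ln Γ(b_i) + Σ (a_i − b_i)(ψ(a_i) − ψ(a_0)), where a_0 and b_0 are the sums of the concentration parameters and ψ is the digamma function. The block below is a minimal pure-Python sketch of that formula, not probly's implementation. The names kl_dirichlet_reference, alpha_first, alpha_second, and digamma are hypothetical, and the sketch deliberately does not assume which of prior_alpha and posterior_alpha plays the role of the first argument of the divergence.

```python
import math


def digamma(x: float) -> float:
    """Approximate the digamma function psi(x) for x > 0 (hypothetical helper)."""
    result = 0.0
    # The recurrence psi(x) = psi(x + 1) - 1/x shifts x into the range
    # where the asymptotic series below is accurate.
    while x < 6.0:
        result -= 1.0 / x
        x += 1.0
    inv2 = 1.0 / (x * x)
    # Asymptotic series: ln x - 1/(2x) - 1/(12x^2) + 1/(120x^4) - 1/(252x^6) + 1/(240x^8)
    series = inv2 * (1.0 / 12 - inv2 * (1.0 / 120 - inv2 * (1.0 / 252 - inv2 / 240)))
    return result + math.log(x) - 0.5 / x - series


def kl_dirichlet_reference(alpha_first, alpha_second):
    """Return KL(Dir(alpha_first[i]) || Dir(alpha_second[i])) for each row i.

    Both inputs are nested lists of shape (B, C) with strictly positive
    entries; the result is a list of length B.
    """
    kls = []
    for a, b in zip(alpha_first, alpha_second):
        a0, b0 = sum(a), sum(b)
        # Difference of the log normalizing constants of the two Dirichlets.
        log_norm = (
            math.lgamma(a0) - sum(math.lgamma(x) for x in a)
            - math.lgamma(b0) + sum(math.lgamma(x) for x in b)
        )
        # Expectation term, taken under the first distribution.
        expect = sum((x - y) * (digamma(x) - digamma(a0)) for x, y in zip(a, b))
        kls.append(log_norm + expect)
    return kls


# Batch of two items with C = 2 classes each.
first = [[1.0, 1.0], [2.0, 2.0]]
second = [[2.0, 2.0], [2.0, 2.0]]
kl = kl_dirichlet_reference(first, second)
# kl[0] is approximately 0.2082 (= 2 - ln 6); kl[1] is 0 because both rows match.
```

The divergence is not symmetric, so swapping the two arguments generally changes the result. In PyTorch, the same quantity can be cross-checked with torch.distributions.kl_divergence applied to two torch.distributions.Dirichlet objects.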