probly.train.credal.torch.intersection_probability_ce_loss

probly.train.credal.torch.intersection_probability_ce_loss(output: Tensor, targets: Tensor) -> Tensor

Cross-entropy on the intersection probability of an interval-valued prediction.

Implements Eq. 14 of [WCM+24]. The packed (B, 2C) interval output, where C is the number of classes, is split into lower and upper bounds. The intersection probability is computed from these bounds, and the negative log-likelihood of the target classes is taken. Before the log, the probabilities are clamped from below at finfo(dtype).eps so that a zero probability does not produce -inf.
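For reference, the intersection probability of a probability interval [l, u] over C classes is commonly defined as follows. This page does not reproduce Eq. 14, so treat this as the assumed standard form rather than a transcription of [WCM+24]. The batch-averaged loss for labels y_i with clamp threshold ε is then:

```latex
\alpha = \frac{1 - \sum_{k=1}^{C} \ell_k}{\sum_{k=1}^{C} (u_k - \ell_k)}, \qquad
p^{\mathrm{int}}_k = \ell_k + \alpha \, (u_k - \ell_k), \qquad
\mathcal{L} = -\frac{1}{B} \sum_{i=1}^{B} \log \max\bigl(p^{\mathrm{int}}_{i, y_i}, \varepsilon\bigr)
```

Summing p^int over k gives the sum of the lower bounds plus α times the sum of the widths, which equals 1. The result is therefore a proper distribution whenever the widths do not all vanish.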

Parameters:
  • output – Packed (B, 2 * num_classes) tensor with the lower bounds in the first half and the upper bounds in the second.

  • targets – Ground-truth class indices of shape (B,).

Returns:

Scalar cross-entropy loss averaged over the batch.
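The following is a pure-Python sketch of the computation described above. It is useful for checking the arithmetic by hand without PyTorch. It assumes the standard intersection-probability definition p_k = l_k + α(u_k − l_k) with α = (1 − Σl) / Σ(u − l). The helper names intersection_probability and intersection_probability_ce are hypothetical and are not probly's API. sys.float_info.epsilon stands in for finfo(dtype).eps of a float64 tensor. The guard for zero total width is an assumption of this sketch.

```python
import math
import sys

# Stand-in for finfo(float64).eps: Python floats are IEEE-754 doubles.
EPS = sys.float_info.epsilon


def intersection_probability(lower, upper):
    """Intersection probability of one interval prediction (hypothetical helper).

    Assumes p_k = l_k + alpha * (u_k - l_k), alpha = (1 - sum(l)) / sum(u - l).
    """
    widths = [u - l for l, u in zip(lower, upper)]
    total_width = sum(widths)
    if total_width == 0.0:
        # Assumed handling of a degenerate interval: lower == upper,
        # so the lower bounds are returned unchanged.
        return list(lower)
    alpha = (1.0 - sum(lower)) / total_width
    return [l + alpha * w for l, w in zip(lower, widths)]


def intersection_probability_ce(output, targets):
    """Batch-mean NLL on packed rows [l_1..l_C, u_1..u_C] (hypothetical helper)."""
    losses = []
    for row, target in zip(output, targets):
        num_classes = len(row) // 2
        # First half holds the lower bounds, second half the upper bounds.
        lower, upper = row[:num_classes], row[num_classes:]
        probs = intersection_probability(lower, upper)
        # Clamp before the log so a zero probability does not give -inf.
        losses.append(-math.log(max(probs[target], EPS)))
    return sum(losses) / len(losses)
```

For example, the row [0.2, 0.3, 0.6, 0.7] has lower bounds (0.2, 0.3) and upper bounds (0.6, 0.7). This gives α = 0.5 / 0.8 = 0.625 and probabilities (0.45, 0.55), so target 0 costs −log 0.45 ≈ 0.799.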