probly.train.credal.torch.intersection_probability_ce_loss
- probly.train.credal.torch.intersection_probability_ce_loss(output: Tensor, targets: Tensor) → Tensor
Cross-entropy on the intersection probability of an interval-valued prediction.
Implements Eq. 14 of [WCM+24]. Splits the packed (B, 2C) interval output, where C is the number of classes, into (lower, upper), computes the intersection probability, and applies the negative log-likelihood loss against the targets. The probabilities are clamped to finfo(dtype).eps before the log to avoid -inf.
- Parameters:
output – Packed (B, 2 * num_classes) tensor with the lower bounds in the first half and the upper bounds in the second.
targets – Ground-truth class indices of shape (B,).
- Returns:
Scalar cross-entropy loss averaged over the batch.
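The computation described above can be sketched in plain PyTorch. This is a minimal, hypothetical re-implementation for illustration, not probly's source: the name intersection_probability_ce_loss_sketch is invented here, and the sketch assumes the standard intersection probability for probability intervals, p_int(k) = l_k + α (u_k − l_k) with α = (1 − Σ_j l_j) / Σ_j (u_j − l_j), which may differ in detail from Eq. 14 of [WCM+24]. Clamping the denominator of α to eps is an extra guard added by this sketch for degenerate (zero-width) intervals.

```python
import torch
import torch.nn.functional as F


def intersection_probability_ce_loss_sketch(output: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch of the loss; assumes the standard intersection probability."""
    # Split the packed (B, 2C) tensor into lower and upper bounds, each (B, C).
    lower, upper = torch.chunk(output, 2, dim=1)
    width = upper - lower
    eps = torch.finfo(output.dtype).eps

    # Assumed definition: alpha = (1 - sum(lower)) / sum(upper - lower), per sample.
    # The denominator clamp is this sketch's own guard against zero-width intervals.
    alpha = (1.0 - lower.sum(dim=1, keepdim=True)) / width.sum(dim=1, keepdim=True).clamp(min=eps)

    # Move each lower bound the same fraction alpha of the way toward its upper bound,
    # so the resulting probabilities sum to 1.
    p_int = lower + alpha * width

    # Clamp to machine epsilon so that log never returns -inf.
    log_p = torch.log(p_int.clamp(min=eps))

    # Negative log-likelihood of the target class, averaged over the batch.
    return F.nll_loss(log_p, targets)
```

For example, with lower bounds (0.2, 0.3) and upper bounds (0.6, 0.5), α = 0.5 / 0.6, so the intersection probability is (8/15, 7/15), and for target 0 the loss is −log(8/15) ≈ 0.629.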