# Source code for probly.conformal_prediction.scores.saps.torch

"""Torch for SAPS."""

from __future__ import annotations

import torch

from .common import register


def saps_score_torch(
    probs: torch.Tensor,
    lambda_val: float,
    u: torch.Tensor | float,
) -> torch.Tensor:
    """Compute SAPS nonconformity scores for torch tensors.

    Every label is scored from its rank among the sample's probabilities:
    the top-ranked label scores ``u * p_max`` and a label of rank ``r >= 2``
    scores ``p_max + (r - 2 + u) * lambda_val``, where ``p_max`` is the
    largest probability of that sample.

    Args:
        probs: Softmax probabilities, either a 2D tensor of shape
            ``(n_samples, n_classes)`` (classes along dim 1) or a 1D tensor
            of shape ``(n_classes,)`` holding a single sample. Converted to
            float32.
        lambda_val: Penalty added per rank beyond the first.
        u: Random value(s) in ``[0, 1)`` (required). Either a scalar (Python
            number or 0-dim tensor) or a tensor broadcastable against the
            2D probabilities, e.g. shape ``(n_samples, 1)`` or
            ``(n_samples, n_classes)``. NOTE: a 1D ``u`` of length
            ``n_samples`` broadcasts along the class axis, not the sample
            axis; reshape it to ``(n_samples, 1)`` for per-sample values.

    Returns:
        torch.Tensor: float32 scores, one per (sample, label). The shape is
        ``probs.shape`` whenever ``u`` broadcasts into it.
    """
    probs = torch.asarray(probs, dtype=torch.float)
    # Build u directly as float on probs' device. Casting it to probs' dtype
    # first would truncate u to 0 for integer-valued probabilities.
    u = torch.asarray(u, dtype=torch.float, device=probs.device)

    single_sample = probs.ndim == 1
    if single_sample:
        # Treat a 1D input as a batch of one so that dim 1 is the class axis.
        probs = probs.unsqueeze(0)

    # Largest probability per sample, kept as a column for broadcasting.
    max_probs = torch.max(probs, dim=1, keepdim=True).values

    # The argsort of the descending argsort gives each label's 0-based
    # position in the sorted order, i.e. its rank; +1 makes it 1-based.
    order = torch.argsort(-probs, dim=1)
    ranks = torch.argsort(order, dim=1) + 1

    scores = torch.where(
        ranks == 1,
        u * max_probs,
        max_probs + (ranks - 2 + u) * lambda_val,
    )
    if single_sample:
        scores = scores.squeeze(0)
    return torch.asarray(scores, dtype=torch.float)
# Register saps_score_torch as the implementation for torch.Tensor inputs via
# `register` from `.common` (presumably a type-based dispatcher that routes
# torch probabilities here — confirm against .common.register).
register(torch.Tensor, saps_score_torch)