probly.method.duq

probly.method.duq(base: Predictor[In, Out], centroid_size: int = 256, length_scale: float = 0.1, gamma: float = 0.999) → DUQPredictor[In, Out]

Transform a model for Deterministic Uncertainty Quantification [vASTG20].

Replaces the original classification head (the last nn.Linear layer) with an RBF centroid head. For each class \(c\), a learnable projection \(W_c \in \mathbb{R}^{n \times d}\) maps the feature vector \(f_\theta(x) \in \mathbb{R}^d\) to an embedding \(z_c = W_c f_\theta(x)\). The kernel value \(K_c(x) = \exp\left(-\|z_c - e_c\|^2 / (2 n \sigma^2)\right)\) is computed against an EMA-updated class centroid \(e_c\). The predicted class is \(\arg\max_c K_c(x)\) and the uncertainty score is \(1 - \max_c K_c(x)\).
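For a single input, the head computation above can be sketched in plain Python. This is a minimal, dependency-free illustration under stated assumptions: lists stand in for tensors, and the helper name `duq_head` and its argument layout are hypothetical, not part of probly's API.

```python
import math


def duq_head(features, projections, centroids, length_scale):
    """Hypothetical sketch of the DUQ RBF head for one input.

    features:    feature vector f_theta(x), a list of length d
    projections: one matrix W_c per class, each a list of n rows of length d
    centroids:   one centroid e_c per class, each a list of length n
    Returns (kernel values K_c, predicted class, uncertainty score).
    """
    kernels = []
    for W_c, e_c in zip(projections, centroids):
        n = len(e_c)
        # z_c = W_c f_theta(x): one dot product per row of W_c
        z_c = [sum(w * f for w, f in zip(row, features)) for row in W_c]
        # Squared Euclidean distance between the embedding and the centroid
        sq_dist = sum((z - e) ** 2 for z, e in zip(z_c, e_c))
        # K_c = exp(-||z_c - e_c||^2 / (2 n sigma^2))
        kernels.append(math.exp(-sq_dist / (2 * n * length_scale ** 2)))
    # Predicted class: argmax_c K_c; uncertainty: 1 - max_c K_c
    predicted = max(range(len(kernels)), key=kernels.__getitem__)
    uncertainty = 1.0 - kernels[predicted]
    return kernels, predicted, uncertainty


identity = [[1.0, 0.0], [0.0, 1.0]]
kernels, predicted, uncertainty = duq_head(
    [1.0, 0.0], [identity, identity], [[1.0, 0.0], [0.0, 1.0]], length_scale=1.0
)
```

With `length_scale=1.0`, the input `[1.0, 0.0]` lies exactly on centroid 0, so \(K_0 = 1\), the predicted class is 0 and the uncertainty score is 0, while \(K_1 = e^{-0.5}\). With the default `length_scale=0.1`, kernel values fall off much faster with distance from a centroid.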

The transformed predictor is intended to be trained from scratch with the binary cross-entropy loss on the kernel values and a two-sided gradient penalty on the inputs, as in the reference implementation. Class centroids must be updated each step via TorchDUQPredictor.update_centroids().
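The per-sample training objective can be sketched as follows, assuming one-hot targets and a gradient \(g\) of \(\sum_c K_c(x)\) with respect to the input that has already been obtained from an autograd framework (plain Python cannot differentiate). The helper names are hypothetical, and the penalty form \((\|g\|_2 - 1)^2\) is an assumption to be checked against the reference implementation.

```python
import math


def duq_bce_loss(kernels, target, eps=1e-12):
    """Hypothetical sketch: binary cross-entropy of each K_c against a one-hot target.

    Each class is treated as an independent binary problem: K_target is pushed
    toward 1 and every other K_c toward 0. The loss is averaged over classes.
    """
    total = 0.0
    for c, k in enumerate(kernels):
        k = min(max(k, eps), 1.0 - eps)  # clamp to avoid log(0)
        y = 1.0 if c == target else 0.0
        total -= y * math.log(k) + (1.0 - y) * math.log(1.0 - k)
    return total / len(kernels)


def two_sided_gradient_penalty(input_grad):
    """Hypothetical sketch: (||g||_2 - 1)^2 for g = d(sum_c K_c)/dx.

    "Two-sided" means gradient norms both above and below 1 are penalized.
    The gradient itself must come from an autograd framework; here it is
    passed in as a flat list of floats.
    """
    norm = math.sqrt(sum(g * g for g in input_grad))
    return (norm - 1.0) ** 2
```

A full step would then minimize something like `duq_bce_loss(...) + lam * two_sided_gradient_penalty(...)`, where `lam` is a hypothetical penalty weight that `duq()` itself does not take.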

Parameters:
  • base – Base classification model to be transformed.

  • centroid_size – Embedding dimension \(n\) of the per-class projections.

  • length_scale – RBF kernel length scale \(\sigma\).

  • gamma – Exponential moving-average decay for the class centroids.
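The role of `gamma` can be illustrated with the centroid update described in the DUQ paper, where each centroid is a ratio of two exponential moving averages: \(N_c \leftarrow \gamma N_c + (1-\gamma) n_c\) and \(m_c \leftarrow \gamma m_c + (1-\gamma) \sum_{i : y_i = c} W_c f_\theta(x_i)\), giving \(e_c = m_c / N_c\), where \(n_c\) is the number of batch items of class \(c\). This sketch assumes probly follows the same scheme; `ema_update_centroids` is a hypothetical helper, and it requires every running count to be initialized to a positive value.

```python
def ema_update_centroids(counts, sums, embeddings, labels, gamma):
    """Hypothetical sketch of one EMA centroid update (DUQ paper style).

    counts[c]     running EMA count N_c of class c (updated in place)
    sums[c]       running EMA sum m_c of class-c embeddings (updated in place)
    embeddings[i] the C embeddings z_c = W_c f(x_i) of batch item i
    labels[i]     true class of batch item i
    Returns the new centroids e_c = m_c / N_c.
    """
    for c in range(len(counts)):
        # Only items whose true class is c contribute, via their class-c embedding.
        members = [emb[c] for emb, y in zip(embeddings, labels) if y == c]
        counts[c] = gamma * counts[c] + (1.0 - gamma) * len(members)
        if members:
            batch_sum = [sum(vals) for vals in zip(*members)]
        else:
            batch_sum = [0.0] * len(sums[c])
        sums[c] = [gamma * m + (1.0 - gamma) * s for m, s in zip(sums[c], batch_sum)]
    return [[m / counts[c] for m in sums[c]] for c in range(len(counts))]


counts = [1.0, 1.0]
sums = [[0.0, 0.0], [0.0, 0.0]]
# One batch item of class 0 with embeddings z_0 = [2, 0] and z_1 = [9, 9].
centroids = ema_update_centroids(counts, sums, [[[2.0, 0.0], [9.0, 9.0]]], [0], gamma=0.5)
```

Only the class-0 centroid moves here, because the batch contains no class-1 items. With the default `gamma=0.999`, each batch shifts a centroid only slightly, which keeps the centroids stable while the feature extractor is still changing.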

Returns:

The transformed DUQ predictor.