probly.method.duq
- probly.method.duq(base: Predictor[In, Out], centroid_size: int = 256, length_scale: float = 0.1, gamma: float = 0.999) → DUQPredictor[In, Out]
Transform a model for Deterministic Uncertainty Quantification [vASTG20].
Replaces the original classification head (the last nn.Linear layer) with an RBF centroid head. For each class \(c\), a learnable projection \(W_c \in \mathbb{R}^{n \times d}\) maps the feature vector \(f_\theta(x) \in \mathbb{R}^d\) to an embedding \(z_c = W_c f_\theta(x)\). The kernel value \(K_c(x) = \exp\left(-\|z_c - e_c\|^2 / (2 n \sigma^2)\right)\) is computed against an EMA-updated class centroid \(e_c\). The predicted class is \(\arg\max_c K_c(x)\) and the uncertainty score is \(1 - \max_c K_c(x)\).
The transformed predictor is intended to be trained from scratch with the binary cross-entropy loss on the kernel values and a two-sided gradient penalty on the inputs, as in the reference implementation. Class centroids must be updated each step via TorchDUQPredictor.update_centroids().
- Parameters:
base – Base classification model to be transformed.
centroid_size – Embedding dimension \(n\) of the per-class projections.
length_scale – RBF kernel length scale \(\sigma\).
gamma – Exponential moving-average decay for the class centroids.
- Returns:
The transformed DUQ predictor.
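The kernel formula and the roles of centroid_size, length_scale, and gamma can be sketched in plain Python. This is an illustrative sketch only, not the library's implementation: the function names (duq_kernel_values, predict_with_uncertainty, ema_update) are hypothetical, inputs are plain lists rather than tensors, and the centroid update shown is a simple exponential moving average, so the actual update_centroids() may differ (for example by tracking per-class counts).

```python
import math


def duq_kernel_values(features, projections, centroids, length_scale=0.1):
    """Return the kernel value K_c(x) for every class c (hypothetical helper).

    features:    f_theta(x), a list of d floats
    projections: one n x d matrix W_c per class (list of n rows of d floats)
    centroids:   one centroid e_c per class (list of n floats)
    """
    kernels = []
    for W_c, e_c in zip(projections, centroids):
        # z_c = W_c f_theta(x): project the features into the class embedding.
        z_c = [sum(w * f for w, f in zip(row, features)) for row in W_c]
        n = len(z_c)  # embedding dimension, i.e. centroid_size
        sq_dist = sum((z - e) ** 2 for z, e in zip(z_c, e_c))
        # K_c(x) = exp(-||z_c - e_c||^2 / (2 n sigma^2))
        kernels.append(math.exp(-sq_dist / (2 * n * length_scale ** 2)))
    return kernels


def predict_with_uncertainty(kernels):
    """Return (argmax_c K_c(x), 1 - max_c K_c(x))."""
    best = max(range(len(kernels)), key=kernels.__getitem__)
    return best, 1.0 - kernels[best]


def ema_update(centroid, batch_mean, gamma=0.999):
    """Plain EMA step (an assumption; the library's update may differ)."""
    return [gamma * e + (1.0 - gamma) * m for e, m in zip(centroid, batch_mean)]


# Toy example with d = 2 features, n = 2 embedding dimensions, 2 classes.
features = [1.0, 2.0]
projections = [
    [[1.0, 0.0], [0.0, 1.0]],  # W_0: identity
    [[0.0, 1.0], [1.0, 0.0]],  # W_1: swaps the two features
]
centroids = [[1.0, 2.0], [0.0, 0.0]]

kernels = duq_kernel_values(features, projections, centroids, length_scale=1.0)
predicted_class, uncertainty = predict_with_uncertainty(kernels)
```

In this toy example the class-0 embedding coincides with its centroid, so its kernel value is 1 and the uncertainty score is 0. A smaller length_scale makes the kernel fall off faster with distance, and a gamma closer to 1 makes the centroids move more slowly.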