Classification Conformal Prediction — PyTorch

Demonstrate all four classification non-conformity scores (lac_score(), APSScore, RAPSScore, SAPSScore) using a small PyTorch module on the scikit-learn Digits dataset.

Each score uses its own conformal wrapper. During calibration the conformal quantile is computed; after calibration representer() returns a boolean inclusion mask (the conformal prediction set).

from __future__ import annotations

import numpy as np
from sklearn.ensemble import RandomForestClassifier
import torch
from torch import nn
from sklearn.datasets import load_digits
from sklearn.model_selection import KFold, train_test_split

from probly.calibrator import calibrate
from probly.metrics._common import average_set_size, empirical_coverage_classification
from probly.method.conformal import conformal_aps, conformal_lac, conformal_raps, conformal_saps
from probly.predictor import LogitClassifier
from probly.representer import representer

# Fix PyTorch's RNG seed so weight initialization (and hence training) is reproducible.
torch.manual_seed(42)
<torch._C.Generator object at 0x7f1e3310e950>

Data preparation

Define and train the model

class SimpleNet(nn.Module, LogitClassifier):
    """A minimal two-layer feed-forward classifier emitting raw logits.

    Mixes in ``LogitClassifier`` so the conformal wrappers know the model's
    outputs are unnormalized class scores.
    """

    def __init__(self, in_features: int, num_classes: int) -> None:
        super().__init__()
        # 16 hidden units between the two affine maps.
        self.fc1 = nn.Linear(in_features, 16)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(16, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalized class logits for the input batch ``x``."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)


# 64 input features / 10 classes match the Digits data loaded above.
model = SimpleNet(64, 10)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

# Short full-batch training loop: 200 Adam steps on the cross-entropy loss.
model.train()
for _epoch in range(200):
    optimizer.zero_grad()
    loss = loss_fn(model(X_train_t), y_train_t)
    loss.backward()
    optimizer.step()
model.eval()
SimpleNet(
  (fc1): Linear(in_features=64, out_features=16, bias=True)
  (relu): ReLU()
  (fc2): Linear(in_features=16, out_features=10, bias=True)
)

LAC score

# Wrap the trained model with the LAC non-conformity score and calibrate
# the conformal quantile on the held-out calibration split.
lac_model = conformal_lac(model)
calibrated_model = calibrate(lac_model, ALPHA, y_calib_t, X_calib_t)
# After calibration, representer() yields boolean inclusion masks (prediction sets).
output = representer(calibrated_model).predict(X_test_t)
lac_cov = empirical_coverage_classification(output, y_test_t)
lac_size = average_set_size(output)
print(f"LAC  — coverage: {lac_cov:.3f}, avg set size: {lac_size:.3f}")
LAC  — coverage: 0.942, avg set size: 1.039

APS score

# APS with randomized tie-breaking; calibrate its quantile, then predict sets.
aps_model = conformal_aps(model, randomized=True)
calibrated_model = calibrate(aps_model, ALPHA, y_calib_t, X_calib_t)
output = representer(calibrated_model).predict(X_test_t)
aps_cov = empirical_coverage_classification(output, y_test_t)
aps_size = average_set_size(output)
print(f"APS  — coverage: {aps_cov:.3f}, avg set size: {aps_size:.3f}")
APS  — coverage: 0.942, avg set size: 1.117

RAPS score

# RAPS adds a size regularizer (lambda_reg) beyond rank k_reg to shrink sets.
raps_model = conformal_raps(model, randomized=True, lambda_reg=0.1, k_reg=0)
calibrated_model = calibrate(raps_model, ALPHA, y_calib_t, X_calib_t)
output = representer(calibrated_model).predict(X_test_t)
raps_cov = empirical_coverage_classification(output, y_test_t)
raps_size = average_set_size(output)
print(f"RAPS — coverage: {raps_cov:.3f}, avg set size: {raps_size:.3f}")
RAPS — coverage: 0.933, avg set size: 1.044

SAPS score

# SAPS variant with its own weighting parameter lambda_val.
saps_model = conformal_saps(model, randomized=True, lambda_val=0.1)
calibrated_model = calibrate(saps_model, ALPHA, y_calib_t, X_calib_t)
output = representer(calibrated_model).predict(X_test_t)
saps_cov = empirical_coverage_classification(output, y_test_t)
saps_size = average_set_size(output)
print(f"SAPS — coverage: {saps_cov:.3f}, avg set size: {saps_size:.3f}")
SAPS — coverage: 0.942, avg set size: 1.256

Summary (Averaged over multiple runs)

# Compare all four scores under 5-fold cross-validation with a random forest,
# to average out the variance of a single train/calibration/test split.
methods = [
    ("LAC", conformal_lac),
    ("APS", conformal_aps),
    ("RAPS", conformal_raps),
    ("SAPS", conformal_saps),
]
res = {name: [] for name, _ in methods}

for fold, (train_idx, test_idx) in enumerate(
    KFold(n_splits=5, shuffle=True, random_state=42).split(X)
):
    X_train, y_train = X[train_idx], y[train_idx]
    X_test, y_test = X[test_idx], y[test_idx]
    # Carve a calibration split out of each fold's training portion.
    X_train, X_calib, y_train, y_calib = train_test_split(
        X_train, y_train, test_size=0.25, random_state=fold
    )

    fold_model = RandomForestClassifier(random_state=fold)
    fold_model.fit(X_train, y_train)
    for name, conformal_func in methods:
        calibrated_model = calibrate(conformal_func(fold_model), ALPHA, y_calib, X_calib)
        output = representer(calibrated_model).predict(X_test)
        res[name].append(
            (empirical_coverage_classification(output, y_test), average_set_size(output))
        )

for name, vals in res.items():
    covs, sizes = zip(*vals)
    print(f"{name} — coverage: {np.mean(covs):.3f} ± {np.std(covs):.3f}, avg set size: {np.mean(sizes):.3f} ± {np.std(sizes):.3f}")
LAC — coverage: 0.938 ± 0.028, avg set size: 0.950 ± 0.031
APS — coverage: 0.956 ± 0.010, avg set size: 1.876 ± 0.093
RAPS — coverage: 0.953 ± 0.018, avg set size: 1.305 ± 0.056
SAPS — coverage: 0.947 ± 0.027, avg set size: 2.018 ± 0.184

Total running time of the script: (0 minutes 1.665 seconds)

Gallery generated by Sphinx-Gallery