Calibration using Temperature Scaling
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T
from tqdm import tqdm
from probly.calibration import Temperature
from probly.evaluation.metrics import expected_calibration_error
# Pick the best available accelerator: CUDA first, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
Using device: mps
Load data
# CIFAR-10: download both splits, then hold out 20% of the training
# data as a calibration set for temperature scaling.
transforms = T.Compose([T.ToTensor()])
train = torchvision.datasets.CIFAR10(
    root="~/datasets", train=True, download=True, transform=transforms
)
train, cal = torch.utils.data.random_split(train, [0.8, 0.2])
test = torchvision.datasets.CIFAR10(
    root="~/datasets", train=False, download=True, transform=transforms
)
# Shuffle the splits used for optimization; keep the test order deterministic.
train_loader = DataLoader(train, batch_size=256, shuffle=True)
cal_loader = DataLoader(cal, batch_size=256, shuffle=True)
test_loader = DataLoader(test, batch_size=256, shuffle=False)
Load neural network
# ResNet-18 with ImageNet weights. `pretrained=True` is deprecated since
# torchvision 0.13; pass the weights enum explicitly (IMAGENET1K_V1 is the
# exact checkpoint `pretrained=True` used to load, so behavior is unchanged).
net = torchvision.models.resnet18(
    weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1
)
# Swap the 1000-class ImageNet head for a 10-class CIFAR-10 head.
net.fc = nn.Linear(512, 10, device=device)
net = net.to(device)
Train neural network
# Fine-tune the whole network for a few epochs with Adam and cross-entropy,
# reporting the mean minibatch loss after each epoch.
epochs = 5
optimizer = optim.Adam(net.parameters())
criterion = nn.CrossEntropyLoss()
for ep in tqdm(range(epochs)):
    net.train()
    epoch_loss = 0.0
    for batch_x, batch_y in train_loader:
        optimizer.zero_grad()
        logits = net(batch_x.to(device))
        batch_loss = criterion(logits, batch_y.to(device))
        batch_loss.backward()
        optimizer.step()
        epoch_loss += batch_loss.item()
    print(f"Epoch {ep + 1}, Running loss: {epoch_loss / len(train_loader)}")
# Compute accuracy and expected calibration error of the raw (uncalibrated)
# network on the test set.
net.eval()
with torch.no_grad():
    outputs = torch.empty(0, device=device)
    targets = torch.empty(0, device=device)
    for batch_x, batch_y in tqdm(test_loader):
        outputs = torch.cat((outputs, net(batch_x.to(device))), dim=0)
        targets = torch.cat((targets, batch_y.to(device)), dim=0)
    # Turn logits into class probabilities before scoring calibration.
    outputs = F.softmax(outputs, dim=1)
correct = (outputs.argmax(dim=1) == targets).sum().item()
total = targets.size(0)
ece = expected_calibration_error(outputs.cpu().numpy(), targets.cpu().numpy(), num_bins=10)
print(f"Accuracy: {correct / total}")
print(f"Expected Calibration Error: {ece}")
20%|██ | 1/5 [00:15<01:00, 15.04s/it]
Epoch 1, Running loss: 0.9242285938019965
40%|████ | 2/5 [00:30<00:45, 15.03s/it]
Epoch 2, Running loss: 0.562462279561219
60%|██████ | 3/5 [00:45<00:30, 15.05s/it]
Epoch 3, Running loss: 0.4154979696699009
80%|████████ | 4/5 [01:00<00:15, 15.03s/it]
Epoch 4, Running loss: 0.3231563491236632
100%|██████████| 5/5 [01:15<00:00, 15.03s/it]
Epoch 5, Running loss: 0.25008974437880666
100%|██████████| 40/40 [00:01<00:00, 26.32it/s]
Accuracy: 0.7667
Expected Calibration Error: 0.11227427605837584
Use the temperature scaling class and fit temperature using the calibration set
# Wrap the trained network in the temperature-scaling calibrator and fit the
# scalar temperature on the held-out calibration loader.
model = Temperature(net)
model = model.to(device)
model.train()
model.fit(cal_loader, learning_rate=0.01, max_iter=100)
100%|██████████| 40/40 [00:01<00:00, 24.88it/s]
# Re-evaluate on the test set after temperature scaling: accuracy should be
# essentially unchanged while ECE drops.
model.eval()
with torch.no_grad():
    outputs = torch.empty(0, device=device)
    targets = torch.empty(0, device=device)
    for batch_x, batch_y in tqdm(test_loader):
        outputs = torch.cat((outputs, model.predict_pointwise(batch_x.to(device))), dim=0)
        targets = torch.cat((targets, batch_y.to(device)), dim=0)
correct = (outputs.argmax(dim=1) == targets).sum().item()
total = targets.size(0)
ece = expected_calibration_error(outputs.cpu().numpy(), targets.cpu().numpy(), num_bins=10)
print(f"Softmax temperature: {model.temperature.item()}")
print(f"Accuracy: {correct / total}")
print(f"Expected Calibration Error: {ece}")
100%|██████████| 40/40 [00:01<00:00, 25.39it/s]
Softmax temperature: 1.1454050540924072
Accuracy: 0.8066
Expected Calibration Error: 0.06147321143448353