Calibration with Label Relaxation
This notebook shows how to use the Label Relaxation loss from https://ojs.aaai.org/index.php/AAAI/article/view/17041 to improve the calibration of a neural network. The example trains a LeNet-style network on the FashionMNIST dataset and then reports its accuracy and expected calibration error on the test set.
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T
from tqdm import tqdm
from probly.evaluation.metrics import expected_calibration_error
from probly.train.calibration.torch import LabelRelaxationLoss
# Pick the compute backend in order of preference: first CUDA GPU, then
# Apple MPS, and finally the CPU as the fallback.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
Using device: mps
Load data
# The only preprocessing step is the conversion of each image to a tensor.
transforms = T.Compose([T.ToTensor()])

# Build the training split first, then the test split; both are stored under
# ~/datasets and downloaded on demand.
train, test = (
    torchvision.datasets.FashionMNIST(
        root="~/datasets",
        train=is_train,
        download=True,
        transform=transforms,
    )
    for is_train in (True, False)
)

# Only the training loader shuffles; the test loader keeps the dataset order.
train_loader, test_loader = (
    DataLoader(split, batch_size=256, shuffle=do_shuffle)
    for split, do_shuffle in ((train, True), (test, False))
)
Define neural network
class LeNet(nn.Module):
    """Implementation of a model with LeNet architecture.

    The network applies two 5x5 convolutions, each followed by 2x2 max pooling
    and a ReLU, and then two fully connected layers. The first fully connected
    layer expects 256 flattened features, which is what the convolutional part
    produces for 28x28 inputs (16 channels of size 4x4).

    Attributes:
        conv1: nn.Module, first convolutional layer
        conv2: nn.Module, second convolutional layer
        fc1: nn.Module, first fully connected layer
        fc2: nn.Module, second fully connected layer
        act: nn.Module, activation function
        max_pool: nn.Module, max pooling layer
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 10) -> None:
        """Initializes an instance of the LeNet class.

        Args:
            in_channels: int, number of channels of the input images
                (defaults to 1, i.e. grayscale images)
            num_classes: int, number of output logits (defaults to 10)
        """
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 8, kernel_size=5)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=5)
        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, num_classes)
        self.act = nn.ReLU()
        self.max_pool = nn.MaxPool2d(2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the LeNet.

        Args:
            x: torch.Tensor, input data; its spatial size must flatten to the
                256 features expected by fc1 (e.g. 28x28 images)

        Returns:
            torch.Tensor, unnormalized class scores (no softmax is applied)
        """
        # Pooling is applied before the activation in both convolutional stages.
        x = self.act(self.max_pool(self.conv1(x)))
        x = self.act(self.max_pool(self.conv2(x)))
        # Keep the batch dimension and flatten everything else for the linear layers.
        x = torch.flatten(x, 1)
        x = self.act(self.fc1(x))
        return self.fc2(x)
# Build the model with its default configuration, then move it to the selected device.
net = LeNet()
net = net.to(device)
Train the neural network using Label Relaxation
epochs = 10
optimizer = optim.Adam(net.parameters())
# NOTE(review): alpha presumably controls the amount of label relaxation;
# confirm against the probly documentation.
criterion = LabelRelaxationLoss(alpha=0.1)

for epoch in tqdm(range(epochs)):
    net.train()
    running_loss = 0.0
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        outputs = net(inputs.to(device))
        loss = criterion(outputs, targets.to(device))
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Report the mean per-batch loss of this epoch.
    print(f"Epoch {epoch + 1}, Running loss: {running_loss / len(train_loader)}")

# compute accuracy and expected calibration error on test set
net.eval()
with torch.no_grad():
    # Collect the per-batch results and concatenate once at the end. This
    # avoids re-copying the accumulated tensor on every batch and keeps the
    # labels in their original integer dtype.
    logit_batches = []
    target_batches = []
    for inpt, target in tqdm(test_loader):
        logit_batches.append(net(inpt.to(device)))
        target_batches.append(target.to(device))
    # Turn the logits into class probabilities for the calibration metric.
    outputs = F.softmax(torch.cat(logit_batches, dim=0), dim=1)
    targets = torch.cat(target_batches, dim=0)

correct = torch.sum(torch.argmax(outputs, dim=1) == targets).item()
total = targets.size(0)
ece = expected_calibration_error(outputs.cpu().numpy(), targets.cpu().numpy(), num_bins=10)
print(f"Accuracy: {correct / total}")
print(f"Expected Calibration Error: {ece}")
10%|█ | 1/10 [00:04<00:37, 4.16s/it]
Epoch 1, Running loss: 0.5653787719442489
20%|██ | 2/10 [00:07<00:31, 3.88s/it]
Epoch 2, Running loss: 0.31656167754467496
30%|███ | 3/10 [00:11<00:26, 3.84s/it]
Epoch 3, Running loss: 0.2688496538933287
40%|████ | 4/10 [00:15<00:22, 3.83s/it]
Epoch 4, Running loss: 0.24309681076952752
50%|█████ | 5/10 [00:19<00:18, 3.80s/it]
Epoch 5, Running loss: 0.2275508500794147
60%|██████ | 6/10 [00:22<00:15, 3.77s/it]
Epoch 6, Running loss: 0.21468015015125275
70%|███████ | 7/10 [00:26<00:11, 3.78s/it]
Epoch 7, Running loss: 0.2035592128621771
80%|████████ | 8/10 [00:30<00:07, 3.77s/it]
Epoch 8, Running loss: 0.19679444286417455
90%|█████████ | 9/10 [00:34<00:03, 3.76s/it]
Epoch 9, Running loss: 0.1865393954388639
100%|██████████| 10/10 [00:38<00:00, 3.81s/it]
Epoch 10, Running loss: 0.179943670395841
100%|██████████| 40/40 [00:00<00:00, 75.43it/s]
Accuracy: 0.8763
Expected Calibration Error: 0.04056899726688862