Calibration with Label Relaxation

This notebook shows how to use the Label Relaxation loss from https://ojs.aaai.org/index.php/AAAI/article/view/17041 to improve the calibration of a neural network. The example trains a LeNet-style convolutional network on the FashionMNIST dataset and measures calibration on the test set with the expected calibration error (ECE).

import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T
from tqdm import tqdm

from probly.evaluation.metrics import expected_calibration_error
from probly.train.calibration.torch import LabelRelaxationLoss

# Pick the best available accelerator, in order of preference: CUDA GPU,
# Apple-silicon MPS backend, and finally the CPU as a fallback.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
Using device: mps

Load the data

# Only ToTensor is applied: images become tensors, with no mean/std normalization.
transforms = T.Compose([T.ToTensor()])

# Build the training split first, then the test split; both are downloaded
# into ~/datasets if not already present.
train, test = (
    torchvision.datasets.FashionMNIST(root="~/datasets", train=is_train, download=True, transform=transforms)
    for is_train in (True, False)
)

# Shuffle only the training data; keep the test order fixed for evaluation.
train_loader = DataLoader(dataset=train, batch_size=256, shuffle=True)
test_loader = DataLoader(dataset=test, batch_size=256, shuffle=False)

Define the neural network

class LeNet(nn.Module):
    """LeNet-style convolutional network for small images.

    The layer sizes are fixed for 28x28 inputs (e.g. FashionMNIST): each of
    the two 5x5 convolutions is followed by 2x2 max pooling, shrinking the
    spatial size 28 -> 24 -> 12 -> 8 -> 4. The flattened feature vector
    therefore has 16 * 4 * 4 = 256 entries.

    Attributes:
        conv1: nn.Module, first convolutional layer (in_channels -> 8 channels, 5x5 kernel)
        conv2: nn.Module, second convolutional layer (8 -> 16 channels, 5x5 kernel)
        fc1: nn.Module, first fully connected layer (256 -> 128)
        fc2: nn.Module, output layer (128 -> num_classes), returns unnormalized logits
        act: nn.Module, ReLU activation function
        max_pool: nn.Module, 2x2 max pooling layer
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 10) -> None:
        """Initializes an instance of the LeNet class.

        Args:
            in_channels: int, number of channels of the input images (1 for FashionMNIST).
            num_classes: int, number of output classes (10 for FashionMNIST).
        """
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 8, kernel_size=5)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=5)
        # 16 channels x 4 x 4 spatial positions for a 28x28 input (see class docstring).
        self.fc1 = nn.Linear(16 * 4 * 4, 128)
        self.fc2 = nn.Linear(128, num_classes)
        self.act = nn.ReLU()
        self.max_pool = nn.MaxPool2d(2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the LeNet.

        Args:
            x: torch.Tensor, batch of images of shape (N, in_channels, 28, 28).

        Returns:
            torch.Tensor, unnormalized class logits of shape (N, num_classes);
            no softmax is applied here.
        """
        x = self.act(self.max_pool(self.conv1(x)))
        x = self.act(self.max_pool(self.conv2(x)))
        # Keep the batch dimension, flatten channels and spatial dims.
        x = torch.flatten(x, 1)
        x = self.act(self.fc1(x))
        x = self.fc2(x)
        return x


net = LeNet()
# nn.Module.to moves the parameters in place (and returns the same module).
net.to(device)

Train the neural network with the Label Relaxation loss and evaluate it on the test set

epochs = 10
optimizer = optim.Adam(net.parameters())
# alpha is the relaxation strength of the Label Relaxation loss (see the paper
# linked at the top); 0.1 is the value used in this example.
criterion = LabelRelaxationLoss(alpha=0.1)
for epoch in tqdm(range(epochs)):
    net.train()
    running_loss = 0.0
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        # The criterion receives raw logits from the network.
        outputs = net(inputs.to(device))
        loss = criterion(outputs, targets.to(device))
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # Mean loss per batch over the epoch.
    print(f"Epoch {epoch + 1}, Running loss: {running_loss / len(train_loader)}")

# compute accuracy and expected calibration error on test set
net.eval()
# Collect per-batch results and concatenate once at the end. Re-concatenating
# a growing tensor every batch copies all previous results each time, and
# seeding the labels with a float tensor would silently turn integer class
# labels into floats.
batch_logits = []
batch_targets = []
with torch.no_grad():
    for inpt, target in tqdm(test_loader):
        batch_logits.append(net(inpt.to(device)))
        batch_targets.append(target.to(device))
outputs = F.softmax(torch.cat(batch_logits, dim=0), dim=1)
targets = torch.cat(batch_targets, dim=0)
correct = torch.sum(torch.argmax(outputs, dim=1) == targets).item()
total = targets.size(0)
ece = expected_calibration_error(outputs.cpu().numpy(), targets.cpu().numpy(), num_bins=10)
print(f"Accuracy: {correct / total}")
print(f"Expected Calibration Error: {ece}")
 10%|█         | 1/10 [00:04<00:37,  4.16s/it]
Epoch 1, Running loss: 0.5653787719442489
 20%|██        | 2/10 [00:07<00:31,  3.88s/it]
Epoch 2, Running loss: 0.31656167754467496
 30%|███       | 3/10 [00:11<00:26,  3.84s/it]
Epoch 3, Running loss: 0.2688496538933287
 40%|████      | 4/10 [00:15<00:22,  3.83s/it]
Epoch 4, Running loss: 0.24309681076952752
 50%|█████     | 5/10 [00:19<00:18,  3.80s/it]
Epoch 5, Running loss: 0.2275508500794147
 60%|██████    | 6/10 [00:22<00:15,  3.77s/it]
Epoch 6, Running loss: 0.21468015015125275
 70%|███████   | 7/10 [00:26<00:11,  3.78s/it]
Epoch 7, Running loss: 0.2035592128621771
 80%|████████  | 8/10 [00:30<00:07,  3.77s/it]
Epoch 8, Running loss: 0.19679444286417455
 90%|█████████ | 9/10 [00:34<00:03,  3.76s/it]
Epoch 9, Running loss: 0.1865393954388639
100%|██████████| 10/10 [00:38<00:00,  3.81s/it]
Epoch 10, Running loss: 0.179943670395841
100%|██████████| 40/40 [00:00<00:00, 75.43it/s]
Accuracy: 0.8763
Expected Calibration Error: 0.04056899726688862