Out-of-Distribution Detection with an Ensemble

import matplotlib.pyplot as plt
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T
from tqdm import tqdm

from probly.evaluation.tasks import out_of_distribution_detection
from probly.quantification.classification import mutual_information
from probly.transformation import ensemble

Load data and create neural network

# Shared settings for every dataset / loader created below.
DATA_ROOT = "~/datasets/"
BATCH_SIZE = 256

# Convert each image to a tensor and flatten it into a single vector.
transforms = T.Compose([T.ToTensor(), torch.flatten])

# In-distribution data: FashionMNIST train and test splits.
train = torchvision.datasets.FashionMNIST(
    root=DATA_ROOT,
    train=True,
    download=True,
    transform=transforms,
)
test = torchvision.datasets.FashionMNIST(
    root=DATA_ROOT,
    train=False,
    download=True,
    transform=transforms,
)
# Only the training loader is shuffled.
train_loader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test, batch_size=BATCH_SIZE, shuffle=False)

# Out-of-distribution data: MNIST test split, preprocessed the same way.
ood = torchvision.datasets.MNIST(
    root=DATA_ROOT,
    train=False,
    download=True,
    transform=transforms,
)
ood_loader = DataLoader(ood, batch_size=BATCH_SIZE, shuffle=False)


class Net(nn.Module):
    """Fully connected classifier with two hidden layers.

    Attributes:
        fc1: nn.Module, linear layer mapping 784 input features to 100 units
        fc2: nn.Module, linear layer mapping 100 units to 100 units
        fc3: nn.Module, linear layer mapping 100 units to 10 outputs
        act: nn.Module, ReLU activation applied after fc1 and fc2
    """

    def __init__(self) -> None:
        """Creates the three linear layers and the shared activation."""
        super().__init__()
        self.fc1 = nn.Linear(784, 100)
        self.fc2 = nn.Linear(100, 100)
        self.fc3 = nn.Linear(100, 10)
        self.act = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Runs the input through both hidden layers and the output layer.

        Args:
            x: torch.Tensor, input data
        Returns:
            torch.Tensor, output of the last linear layer (no activation applied)
        """
        hidden = x
        # Both hidden layers are followed by the same activation.
        for layer in (self.fc1, self.fc2):
            hidden = self.act(layer(hidden))
        return self.fc3(hidden)
100%|██████████| 26.4M/26.4M [00:01<00:00, 13.3MB/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 2.76MB/s]
100%|██████████| 4.42M/4.42M [00:00<00:00, 11.6MB/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 15.1MB/s]
100%|██████████| 9.91M/9.91M [00:02<00:00, 4.38MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 281kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 1.74MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 6.01MB/s]

Turn the neural network into an ensemble

# Build the ensemble from a freshly constructed Net.
# Presumably the second argument (5) is the number of ensemble members — confirm against the probly docs.
# NOTE(review): this rebinds the name `ensemble`, shadowing the function imported from
# `probly.transformation`; executing this line a second time would call the ensemble
# object rather than the factory function.
ensemble = ensemble(Net(), 5)

Train each ensemble member as usual

# Every member is trained independently with the same loss and its own optimizer.
NUM_EPOCHS = 10
criterion = nn.CrossEntropyLoss()

for member in tqdm(ensemble):
    member_optimizer = optim.Adam(member.parameters())
    for _epoch in range(NUM_EPOCHS):
        member.train()
        for batch_inputs, batch_targets in train_loader:
            member_optimizer.zero_grad()
            # Forward pass and loss for the current mini-batch.
            batch_loss = criterion(member(batch_inputs), batch_targets)
            batch_loss.backward()
            member_optimizer.step()

# compute accuracy of the ensemble's averaged prediction on the test set
correct = 0
total = 0
ensemble.eval()
# Evaluation does not need gradients; disabling autograd avoids building a
# computation graph for every batch and every ensemble member.
with torch.no_grad():
    for inputs, targets in test_loader:
        # Stack the members' softmax outputs along a new dimension 1,
        # then average over that dimension to get the ensemble prediction.
        member_probs = torch.stack(
            [torch.softmax(model(inputs), dim=1) for model in ensemble],
            dim=1,
        )
        mean_probs = member_probs.mean(dim=1)
        correct += (mean_probs.argmax(1) == targets).sum().item()
        total += targets.size(0)
print(f"Accuracy: {correct / total}")
100%|██████████| 5/5 [01:18<00:00, 15.70s/it]
Accuracy: 0.8794

Compute epistemic uncertainty for in-distribution (FashionMNIST) and out-of-distribution (MNIST) data

@torch.no_grad()
def torch_get_outputs(model: nn.Module, loader: DataLoader) -> torch.Tensor:
    """Collect the softmax outputs of every ensemble member for all data in a loader.

    Args:
        model: nn.Module, iterable of member models (each is called on a batch)
        loader: DataLoader, yields (data, target) pairs; the targets are ignored
    Returns:
        torch.Tensor, per-batch results concatenated along dimension 0, where
        dimension 1 indexes the ensemble members
    """
    per_batch = [
        # softmax over dim 1 of each member's output, stacked along a new member dimension
        torch.stack([torch.softmax(member(batch), dim=1) for member in model], dim=1)
        for batch, _ in loader
    ]
    return torch.cat(per_batch, dim=0)


# Member predictions for in-distribution and out-of-distribution data,
# converted to NumPy arrays right away.
outputs_id = torch_get_outputs(ensemble, test_loader).numpy()
outputs_ood = torch_get_outputs(ensemble, ood_loader).numpy()

# Mutual information of the member predictions serves as the uncertainty score.
uncertainty_id = mutual_information(outputs_id)
uncertainty_ood = mutual_information(outputs_ood)

Perform the out-of-distribution detection task

# Overlay the two uncertainty histograms, in-distribution first.
for scores, legend_label in (
    (uncertainty_id, "In-Distribution"),
    (uncertainty_ood, "Out-of-Distribution"),
):
    plt.hist(scores, bins=50, alpha=0.5, label=legend_label)
plt.legend()
plt.show()

# Score the separation of the two uncertainty sets with the OoD detection task.
auroc = out_of_distribution_detection(uncertainty_id, uncertainty_ood)
print(f"AUROC with FashionMNIST as iD and MNIST as OoD: {auroc:.3f}")
../../_images/c0be30d7c0cf56ba1802f221150dc2b9a3e4956ceef45b177d067dc80e5a4349.png
AUROC with FashionMNIST as iD and MNIST as OoD: 0.872