Out-of-Distribution Detection with an Ensemble¶
import matplotlib.pyplot as plt
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T
from tqdm import tqdm
from probly.evaluation.tasks import out_of_distribution_detection
from probly.quantification.classification import mutual_information
from probly.transformation import ensemble
Load data and create neural network¶
# Shared settings for every dataset and loader below.
DATA_ROOT = "~/datasets/"
BATCH_SIZE = 256

# Convert each image to a tensor and flatten it into a single vector so it can
# be fed directly to the fully connected network.
transforms = T.Compose([T.ToTensor(), torch.flatten])

# In-distribution data: FashionMNIST train/test splits.
train = torchvision.datasets.FashionMNIST(
    root=DATA_ROOT, train=True, download=True, transform=transforms
)
test = torchvision.datasets.FashionMNIST(
    root=DATA_ROOT, train=False, download=True, transform=transforms
)
# Out-of-distribution data: the MNIST test split, preprocessed identically.
ood = torchvision.datasets.MNIST(
    root=DATA_ROOT, train=False, download=True, transform=transforms
)

# Only the training loader is shuffled; evaluation loaders keep dataset order.
train_loader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test, batch_size=BATCH_SIZE, shuffle=False)
ood_loader = DataLoader(ood, batch_size=BATCH_SIZE, shuffle=False)
class Net(nn.Module):
    """Simple fully connected neural network with two hidden layers.

    Attributes:
        fc1: nn.Module, first fully connected layer
        fc2: nn.Module, second fully connected layer
        fc3: nn.Module, third fully connected layer (produces the logits)
        act: nn.Module, activation function
    """

    def __init__(
        self,
        in_features: int = 784,
        hidden_features: int = 100,
        num_classes: int = 10,
    ) -> None:
        """Initializes an instance of the Net class.

        Args:
            in_features: int, size of each flattened input sample
            hidden_features: int, width of both hidden layers
            num_classes: int, number of output logits
        """
        super().__init__()
        # Layers are created in forward order so that parameter registration
        # (and seeded initialization) stays the same as with fixed sizes.
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, hidden_features)
        self.fc3 = nn.Linear(hidden_features, num_classes)
        self.act = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the neural network.

        Args:
            x: torch.Tensor, input data whose last dimension is ``in_features``

        Returns:
            torch.Tensor, unnormalized class scores (no softmax is applied)
        """
        hidden = self.act(self.fc1(x))
        hidden = self.act(self.fc2(hidden))
        return self.fc3(hidden)
100%|██████████| 26.4M/26.4M [00:01<00:00, 13.3MB/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 2.76MB/s]
100%|██████████| 4.42M/4.42M [00:00<00:00, 11.6MB/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 15.1MB/s]
100%|██████████| 9.91M/9.91M [00:02<00:00, 4.38MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 281kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 1.74MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 6.01MB/s]
Turn the neural network into an ensemble¶
# Apply the probly ensemble transformation to a fresh Net, passing 5 as the
# second argument (presumably the number of ensemble members — TODO confirm).
# NOTE(review): this rebinds the name `ensemble`, shadowing the function imported
# from probly.transformation, so the transformation cannot be called again
# after this line without re-importing it.
ensemble = ensemble(Net(), 5)
Train each ensemble member as usual¶
criterion = nn.CrossEntropyLoss()
num_epochs = 10

# Each ensemble member is trained on its own, with a separate Adam optimizer.
for model in tqdm(ensemble):
    optimizer = optim.Adam(model.parameters())
    for _ in range(num_epochs):
        model.train()
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            # Forward pass and loss in one step; the criterion receives raw logits.
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()
# compute accuracy on test set
correct = 0
total = 0
ensemble.eval()
# Evaluation only: disable autograd so no computation graph is built for
# every member and batch.
with torch.no_grad():
    for inputs, targets in test_loader:
        # The ensemble prediction is the mean of the members' softmax outputs.
        member_probs = [torch.softmax(member(inputs), dim=1) for member in ensemble]
        outputs = torch.stack(member_probs, dim=1).mean(dim=1)
        correct += (outputs.argmax(1) == targets).sum().item()
        total += targets.size(0)
print(f"Accuracy: {correct / total}")
100%|██████████| 5/5 [01:18<00:00, 15.70s/it]
Accuracy: 0.8794
Compute epistemic uncertainty for in-distribution (FashionMNIST) and out-of-distribution (MNIST) data¶
@torch.no_grad()
def torch_get_outputs(model: nn.Module, loader: DataLoader) -> torch.Tensor:
    """Generate the softmax outputs of every ensemble member for data from a loader.

    Args:
        model: nn.Module, ensemble; it is iterated to obtain the member networks
        loader: DataLoader, data loader yielding (data, target) pairs; the
            targets are ignored

    Returns:
        torch.Tensor, member probabilities stacked along dim 1 and batches
        concatenated along dim 0. For members returning (batch, classes)
        logits this is (samples, members, classes).
    """
    outputs = []
    for data, _ in loader:
        # Softmax over dim 1 turns each member's logits into class probabilities.
        member_probs = [torch.softmax(m(data), dim=1) for m in model]
        outputs.append(torch.stack(member_probs, dim=1))
    return torch.cat(outputs, dim=0)
# Collect the per-member probabilities for both datasets and convert them
# to NumPy arrays right away.
outputs_id = torch_get_outputs(ensemble, test_loader).numpy()
outputs_ood = torch_get_outputs(ensemble, ood_loader).numpy()

# Mutual information over the ensemble members serves as the uncertainty score
# for the in-distribution and the out-of-distribution samples.
uncertainty_id = mutual_information(outputs_id)
uncertainty_ood = mutual_information(outputs_ood)
Perform the out-of-distribution detection task¶
# Overlay the two uncertainty distributions in one semi-transparent histogram.
for scores, label in (
    (uncertainty_id, "In-Distribution"),
    (uncertainty_ood, "Out-of-Distribution"),
):
    plt.hist(scores, bins=50, alpha=0.5, label=label)
plt.legend()
plt.show()

# Score the detection task from the two sets of uncertainties and report it.
auroc = out_of_distribution_detection(uncertainty_id, uncertainty_ood)
print(f"AUROC with FashionMNIST as iD and MNIST as OoD: {auroc:.3f}")
AUROC with FashionMNIST as iD and MNIST as OoD: 0.872