# Bayesian Neural Networks

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T
from tqdm import tqdm

from probly.train.bayesian.torch import ELBOLoss, collect_kl_divergence
from probly.transformation import bayesian

## Prepare the data

# Fashion-MNIST: 60k train / 10k test grayscale 28x28 images over 10 classes.
# ToTensor converts PIL images to float tensors in [0, 1]; no normalization.
preprocess = T.Compose([T.ToTensor()])

train = torchvision.datasets.FashionMNIST(
    root="~/datasets", train=True, download=True, transform=preprocess
)
test = torchvision.datasets.FashionMNIST(
    root="~/datasets", train=False, download=True, transform=preprocess
)

# Shuffle only the training stream; evaluation order is kept fixed.
train_loader = DataLoader(train, batch_size=256, shuffle=True)
test_loader = DataLoader(test, batch_size=256, shuffle=False)

## Define a simple neural network and make it Bayesian

class LeNet(nn.Module):
    """Small LeNet-style convolutional network for 28x28 grayscale images.

    Two conv + max-pool stages followed by two fully connected layers.
    With 28x28 inputs, the feature map entering ``fc1`` is
    16 channels x 4 x 4 = 256 values, which fixes ``fc1``'s input size;
    other spatial input sizes are not supported without changing ``fc1``.

    Attributes:
        conv1: nn.Module, first convolutional layer (in_channels -> 8, 5x5)
        conv2: nn.Module, second convolutional layer (8 -> 16, 5x5)
        fc1: nn.Module, first fully connected layer (256 -> 128)
        fc2: nn.Module, second fully connected layer (128 -> num_classes)
        act: nn.Module, ReLU activation shared by all hidden layers
        max_pool: nn.Module, 2x2 max pooling layer
    """

    def __init__(self, num_classes: int = 10, in_channels: int = 1) -> None:
        """Initializes an instance of the LeNet class.

        Args:
            num_classes: int, number of output logits (default 10, matching
                Fashion-MNIST)
            in_channels: int, number of input image channels (default 1 for
                grayscale)
        """
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 8, kernel_size=5)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=5)
        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, num_classes)
        self.act = nn.ReLU()
        self.max_pool = nn.MaxPool2d(2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the LeNet.

        Args:
            x: torch.Tensor, input batch of shape (N, in_channels, 28, 28)
        Returns:
            torch.Tensor, logits of shape (N, num_classes)
        """
        x = self.act(self.max_pool(self.conv1(x)))  # (N, 8, 12, 12)
        x = self.act(self.max_pool(self.conv2(x)))  # (N, 16, 4, 4)
        x = torch.flatten(x, 1)                     # (N, 256)
        x = self.act(self.fc1(x))
        return self.fc2(x)                          # raw logits, no softmax


# Build a plain (deterministic) LeNet and hand it to probly's `bayesian`
# transformation, which returns a model whose layers carry weight
# distributions instead of point estimates.
model = bayesian(LeNet())

## Train the Bayesian neural network using the ELBO loss

# Training setup: 20 epochs, Adam with default learning rate, and an ELBO
# criterion whose KL term is down-weighted by 1e-5 relative to the data term.
epochs = 20
optimizer = optim.Adam(model.parameters())
criterion = ELBOLoss(1e-5)

for epoch in tqdm(range(epochs)):
    model.train()
    epoch_loss = 0.0
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        predictions = model(inputs)
        # Sum of KL divergences over all stochastic layers; re-collected every
        # step because the variational parameters change after each update.
        kl = collect_kl_divergence(model)
        loss = criterion(predictions, targets, kl)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    # `kl` here is the value from the last batch of the epoch.
    print(f"Epoch {epoch + 1}, Running loss: {epoch_loss / len(train_loader)}, KL: {kl.item()}")

# compute accuracy on test set
correct = 0
total = 0
model.eval()
# no_grad() disables autograd graph construction during inference — the
# original loop built (and kept) gradient state for every batch, wasting
# memory and time for no benefit.
with torch.no_grad():
    for inputs, targets in test_loader:
        outputs = model(inputs)
        # argmax over the class dimension; .item() keeps `correct` a plain
        # Python int rather than accumulating a 0-dim tensor.
        correct += (outputs.argmax(1) == targets).sum().item()
        total += targets.size(0)
print(f"Accuracy: {correct / total}")
  5%|▌         | 1/20 [00:10<03:20, 10.53s/it]
Epoch 1, Running loss: 2.4207534445093035, KL: 94045.21875
 10%|█         | 2/20 [00:20<03:04, 10.24s/it]
Epoch 2, Running loss: 1.7625010044016736, KL: 94134.6875
 15%|█▌        | 3/20 [00:30<02:52, 10.18s/it]
Epoch 3, Running loss: 1.6143848358316624, KL: 94217.8046875
 20%|██        | 4/20 [00:40<02:42, 10.13s/it]
Epoch 4, Running loss: 1.544742205295157, KL: 94299.0
 25%|██▌       | 5/20 [00:50<02:31, 10.10s/it]
Epoch 5, Running loss: 1.4946847834485641, KL: 94383.8046875
 30%|███       | 6/20 [01:00<02:21, 10.13s/it]
Epoch 6, Running loss: 1.4544967428166815, KL: 94470.3359375
 35%|███▌      | 7/20 [01:10<02:11, 10.09s/it]
Epoch 7, Running loss: 1.4172436328644449, KL: 94555.3984375
 40%|████      | 8/20 [01:21<02:01, 10.09s/it]
Epoch 8, Running loss: 1.3891127028363817, KL: 94646.71875
 45%|████▌     | 9/20 [01:31<01:53, 10.33s/it]
Epoch 9, Running loss: 1.3701187615698955, KL: 94740.890625
 50%|█████     | 10/20 [01:42<01:43, 10.31s/it]
Epoch 10, Running loss: 1.3505153067568516, KL: 94835.7421875
 55%|█████▌    | 11/20 [01:52<01:32, 10.27s/it]
Epoch 11, Running loss: 1.335881043495016, KL: 94929.421875
 60%|██████    | 12/20 [02:02<01:22, 10.31s/it]
Epoch 12, Running loss: 1.3261284645567548, KL: 95024.1953125
 65%|██████▌   | 13/20 [02:13<01:12, 10.34s/it]
Epoch 13, Running loss: 1.3145434059995287, KL: 95121.8203125
 70%|███████   | 14/20 [02:23<01:02, 10.36s/it]
Epoch 14, Running loss: 1.303467125588275, KL: 95216.2890625
 75%|███████▌  | 15/20 [02:34<00:52, 10.42s/it]
Epoch 15, Running loss: 1.2922741068170425, KL: 95312.921875
 80%|████████  | 16/20 [02:44<00:41, 10.43s/it]
Epoch 16, Running loss: 1.2869641364888942, KL: 95409.4765625
 85%|████████▌ | 17/20 [02:55<00:31, 10.44s/it]
Epoch 17, Running loss: 1.2798084436578954, KL: 95506.921875
 90%|█████████ | 18/20 [03:05<00:20, 10.40s/it]
Epoch 18, Running loss: 1.2714758502676131, KL: 95601.9375
 95%|█████████▌| 19/20 [03:15<00:10, 10.36s/it]
Epoch 19, Running loss: 1.2655935074420686, KL: 95705.59375
100%|██████████| 20/20 [03:26<00:00, 10.30s/it]
Epoch 20, Running loss: 1.2595517417217823, KL: 95808.8359375

Accuracy: 0.8776000142097473