MC-Dropout uncertainty on a 2-D stream

A tiny PyTorch MLP trained one sample at a time on a 2-D classification stream. Wrapping the network with dropout() keeps the dropout layers active during inference, so a single pass through representer() followed by quantify() yields an MC-Dropout uncertainty decomposition at every step.

Halfway through the run we swap the class means of the data distribution. Epistemic uncertainty rises immediately after the swap because the network’s stochastic forward passes start to disagree on the unfamiliar inputs.

from __future__ import annotations

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn

from probly.method.dropout import dropout
from probly.quantification import quantify
from probly.representation.distribution.torch_categorical import TorchCategoricalDistribution
from probly.representer import representer

# Seed torch (weight init, dropout masks) and NumPy (data stream) so the
# whole run is reproducible.
torch.manual_seed(0)
rng = np.random.default_rng(0)

Define a tiny network that returns a categorical distribution.

class TinyNet(nn.Module):
    """Two-layer MLP for binary classification."""

    def __init__(self) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2, 16),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(16, 2),
        )

    def forward(self, x: torch.Tensor) -> TorchCategoricalDistribution:
        return TorchCategoricalDistribution(torch.softmax(self.net(x), dim=-1))


base = TinyNet()
# Wrap the network so dropout stays active at inference time (MC-Dropout);
# p mirrors the nn.Dropout rate used inside TinyNet.
mc_model = dropout(base, p=0.3)
# The optimizer updates the *base* parameters; the wrapper shares them.
opt = torch.optim.Adam(base.parameters(), lr=5e-2)

Build a 2-D classification stream that swaps its class means at t = 300.

N_STEPS = 600
DRIFT_T = 300


def stream_step(
    t: int, generator: np.random.Generator | None = None
) -> tuple[np.ndarray, int]:
    """Return one (sample, label) pair from the drifting 2-D stream.

    Before ``DRIFT_T`` class 0 is centred at (0, 0) and class 1 at (2, 2);
    from ``DRIFT_T`` onwards the two means are swapped, so inputs the model
    has learned suddenly carry the opposite label.

    Args:
        t: Current time step; decides which side of the drift we are on.
        generator: Optional NumPy generator to draw from. Defaults to the
            module-level ``rng`` so existing callers are unchanged; passing
            one explicitly makes the function self-contained and testable.

    Returns:
        ``(x, cls)`` where ``x`` is a float32 array of shape (2,) drawn
        from N(mu, 0.5) and ``cls`` is the class label in {0, 1}.
    """
    gen = rng if generator is None else generator
    cls = int(gen.integers(0, 2))
    # XOR of "class 1" and "post-drift": exactly one of them puts the
    # sample at the (2, 2) mean, matching the original if/else table.
    if (cls == 1) != (t >= DRIFT_T):
        mu = np.array([2.0, 2.0])
    else:
        mu = np.array([0.0, 0.0])
    return gen.normal(mu, 0.5).astype(np.float32), cls

Test-then-train loop. representer(mc_model, num_samples=10) runs 10 stochastic forward passes per step.

# Per-step uncertainty traces, filled in by the loop below.
epi = np.zeros(N_STEPS)
alea = np.zeros(N_STEPS)
# 10 stochastic forward passes (fresh dropout masks) per represent() call.
sampler = representer(mc_model, num_samples=10)

for t in range(N_STEPS):
    x_np, y = stream_step(t)
    x = torch.from_numpy(x_np).unsqueeze(0)  # add batch dim: (1, 2)

    # Test first: quantify uncertainty on this sample before the model
    # has seen its label.
    with torch.no_grad():
        decomp = quantify(sampler.represent(x))
    epi[t] = float(decomp.epistemic)
    alea[t] = float(decomp.aleatoric)

    # Then train one step on the raw logits — cross_entropy applies
    # log-softmax internally, so we call base.net directly and bypass
    # TinyNet.forward's softmax.
    base.train()
    logits = base.net(x)
    loss = nn.functional.cross_entropy(logits, torch.tensor([y]))
    opt.zero_grad()
    loss.backward()
    opt.step()

Plot epistemic and aleatoric uncertainty over time.

# Smooth both traces with a centred 15-step moving average so the drift
# response is visible through the per-step noise.
window = 15
kernel = np.full(window, 1.0 / window)
epi_smooth = np.convolve(epi, kernel, mode="same")
alea_smooth = np.convolve(alea, kernel, mode="same")

fig, ax = plt.subplots(figsize=(7, 3.2))
ax.plot(alea_smooth, color="#1f77b4", lw=1.2, label="aleatoric")
ax.plot(epi_smooth, color="#d62728", lw=1.4, label="epistemic")
# Mark the moment the class means are swapped.
ax.axvline(DRIFT_T, color="black", ls="--", lw=0.8, alpha=0.5, label="class swap")
ax.set_xlabel("step t")
ax.set_ylabel("uncertainty")
ax.set_title("MC-Dropout MLP on a 2-D stream with class swap")
ax.legend(frameon=False, fontsize=9)
fig.tight_layout()
plt.show()
MC-Dropout MLP on a 2-D stream with class swap

Total running time of the script: (0 minutes 4.635 seconds)

Gallery generated by Sphinx-Gallery