Batch Ensemble Networks

In this notebook, we:

A) Explain the idea behind Batch Ensembles.

B) Create a small Multi-Layer-Perceptron (MLP).

C) Train the MLP as an Ensemble and a BatchEnsemble network on CIFAR10.

D) Compare accuracy and speed.

A) Introduction: What are Batch Ensembles?

Batch Ensembles are a way to efficiently approximate an ensemble of neural networks. Traditional ensembles require training and storing multiple independent networks, which is expensive in both memory and computation.

Key ideas:

  • Use shared base weights (and biases) for all ensemble members.

  • Introduce rank-1 multiplicative factors for each member.

  • Much faster and memory-efficient than classic ensembles.

Mathematically, the classic forward pass

$$y = Wx + b$$

transforms, for ensemble member $i$, into

$$y_i = \left(W\,(x \circ s_i)\right) \circ r_i + b$$

where $r_i, s_i$ are the rank-1 vectors of ensemble member $i$ and $\circ$ denotes element-wise multiplication. This is equivalent to giving member $i$ the effective weight matrix $W \circ (r_i s_i^T)$, so each member adds only two vectors on top of the shared $W$.
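A minimal numerical check of this equivalence in plain PyTorch (independent of probly; the dimensions are arbitrary):

import torch

out_dim, in_dim = 4, 6
W = torch.randn(out_dim, in_dim)
b = torch.randn(out_dim)
x = torch.randn(in_dim)
s = torch.randn(in_dim)   # input modulation vector of one member
r = torch.randn(out_dim)  # output modulation vector of one member

# Efficient BatchEnsemble form: scale the input, apply the shared W, scale the output.
y_fast = (W @ (x * s)) * r + b
# Explicit form: the member's effective weight is W scaled element-wise by r s^T.
y_slow = (W * torch.outer(r, s)) @ x + b

print(torch.allclose(y_fast, y_slow, atol=1e-6))  # True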

What does the transformation do?

The parameters

  • num_members: The number of ensemble members to create.

  • s_mean: The mean used to initialize the input modulation factor s.

  • s_std: The standard deviation used to initialize the input modulation factor s.

  • r_mean: The mean used to initialize the output modulation factor r.

  • r_std: The standard deviation used to initialize the output modulation factor r.
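These means and standard deviations presumably feed a normal initialization of the per-member vectors. A sketch of what that could look like (the exact scheme inside probly may differ; the values below are assumptions for illustration):

import torch

num_members, in_features, out_features = 5, 3072, 128
s_mean, s_std = 1.0, 0.1  # assumed values
r_mean, r_std = 1.0, 0.1  # assumed values

# One modulation vector per member; initializing near 1 keeps every member
# close to the shared base network at the start of training.
s = torch.randn(num_members, in_features) * s_std + s_mean
r = torch.randn(num_members, out_features) * r_std + r_mean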

The layers

With these parameters we can transform:

  • Linear layer into BatchEnsembleLinear layer

  • Conv2d layer into BatchEnsembleConv2d layer

The transformation keeps the dimensions and base weights, while adding rank-1 factors:

  • s: scales input-dimension per member

  • r: scales output-dimension per member

The base weights (weight) and bias (bias) are shared across all members, keeping memory usage minimal. The differences between members arise solely from their individual scaling factors s and r.
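For intuition, here is a minimal sketch of what such a layer could look like (an illustration only, not probly's actual implementation):

import torch
from torch import nn


class ToyBatchEnsembleLinear(nn.Module):
    """Illustrative only: shared weight and bias plus per-member rank-1 factors."""

    def __init__(self, in_features: int, out_features: int, num_members: int) -> None:
        super().__init__()
        self.num_members = num_members
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.01)  # shared
        self.bias = nn.Parameter(torch.zeros(out_features))  # shared
        self.s = nn.Parameter(torch.ones(num_members, in_features))   # input scaling per member
        self.r = nn.Parameter(torch.ones(num_members, out_features))  # output scaling per member

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.dim() == 2:  # [B, in] -> one copy per member: [E, B, in]
            x = x.unsqueeze(0).expand(self.num_members, -1, -1)
        x = x * self.s.unsqueeze(1)       # scale inputs per member
        y = x @ self.weight.t()           # one shared matmul serves all members
        return y * self.r.unsqueeze(1) + self.bias  # scale outputs per member

Each extra member costs only in_features + out_features additional parameters, instead of a full weight matrix as in a classic ensemble.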

B) Setup of the MLP

Standard Imports and Pytorch Setup

import time

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)
Using device: cpu

Import CIFAR10 Dataset

transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        # add more transforms if desired
    ],
)

train_data = CIFAR10(root="./data", train=True, transform=transform, download=True)
val_data = CIFAR10(root="./data", train=False, transform=transform, download=True)

train_loader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=2)
val_loader = DataLoader(val_data, batch_size=32, shuffle=False, num_workers=2)

print(f"Train samples: {len(train_data)},  Val samples: {len(val_data)}")
Train samples: 50000,  Val samples: 10000

The MLP Class

We create an MLP inheriting basic functionality from the nn.Module parent class. The MLP has two hidden layers using the ReLU activation function.

class MLP(nn.Module):
    def __init__(self, in_dim: int = 3072, hidden: int = 128, out_dim: int = 10) -> None:
        """Initialize the MLP model with two hidden layers.

        Args:
            in_dim (int): Dimension of the input features. Default is 3072 (32x32x3 for CIFAR-10).
            hidden (int): Number of neurons in the hidden layers. Default is 128.
            out_dim (int): Dimension of the output features. Default is 10 (number of classes in CIFAR-10).
        """
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the MLP model.

        Before passing the input through the network, it flattens the input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_dim).

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, out_dim).
        """
        x = x.view(x.size(0), -1)
        return self.net(x)
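As a quick sanity check, the (untrained) MLP maps a CIFAR-10-shaped batch to one logit vector per sample:

mlp = MLP()
dummy = torch.randn(8, 3, 32, 32)  # batch of 8 images, 3x32x32 each
print(mlp(dummy).shape)  # torch.Size([8, 10])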

The Ensemble MLP

from probly.transformation.ensemble import ensemble

in_dim = 3 * 32 * 32
hidden = 128
num_members = 5

ensemble_mlp = ensemble(
    base=MLP(in_dim=in_dim, hidden=hidden, out_dim=10),
    num_members=num_members,
)

The BatchEnsemble MLP

from probly.transformation import batchensemble

batch_ensemble_mlp = batchensemble(
    base=MLP(in_dim=in_dim, hidden=hidden, out_dim=10),
    num_members=num_members,
)
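The transformed model returns one output per member, stacked along a leading member dimension; this is the shape the evaluation code below relies on:

dummy = torch.randn(8, 3, 32, 32)
out = batch_ensemble_mlp(dummy)
print(out.shape)  # expected: [num_members, 8, 10], i.e. [E, B, out_dim]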

Let’s compare the different models now.

We start with a comparison of the base MLP and the BatchEnsemble MLP:

print(f"Base MLP:\n{MLP(in_dim=in_dim, hidden=hidden, out_dim=10)}\n")
print(f"BatchEnsemble MLP:\n{batch_ensemble_mlp}")
Base MLP:
MLP(
  (net): Sequential(
    (0): Linear(in_features=3072, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=10, bias=True)
  )
)

BatchEnsemble MLP:
MLP(
  (net): Sequential(
    (0): BatchEnsembleLinear()
    (1): ReLU()
    (2): BatchEnsembleLinear()
    (3): ReLU()
    (4): BatchEnsembleLinear()
  )
)

Then we compare the Ensemble MLP and the BatchEnsemble MLP:

print(f"Ensemble MLP:\n{ensemble_mlp}\n")
print(f"BatchEnsemble MLP:\n{batch_ensemble_mlp}")
Ensemble MLP:
ModuleList(
  (0-4): 5 x MLP(
    (net): Sequential(
      (0): Linear(in_features=3072, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): ReLU()
      (4): Linear(in_features=128, out_features=10, bias=True)
    )
  )
)

BatchEnsemble MLP:
MLP(
  (net): Sequential(
    (0): BatchEnsembleLinear()
    (1): ReLU()
    (2): BatchEnsembleLinear()
    (3): ReLU()
    (4): BatchEnsembleLinear()
  )
)
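The printouts show the structure; parameter counts make the memory savings concrete. A small helper (not part of probly):

def count_params(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters())


print(f"Ensemble MLP params:      {count_params(ensemble_mlp):,}")
print(f"BatchEnsemble MLP params: {count_params(batch_ensemble_mlp):,}")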

C) Training

Training Methods

Since probly currently provides no training functionality, we define the training methods below.

Base Training Method

def train_model(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    loss_function: nn.CrossEntropyLoss,
    train_loader: DataLoader,
    epochs: int = 10,
    num_members: int | None = None,
) -> nn.Module:
    """Train a model; when num_members is set, average the loss over the stacked member outputs."""
    for epoch in range(epochs):
        t0 = time.perf_counter()
        total_loss = 0.0
        model.train()
        for xb, yb in train_loader:
            x = xb.to(device).float()
            y = yb.to(device).long()
            optimizer.zero_grad()
            out = model(x)

            if isinstance(num_members, int) and num_members > 0:
                # out has shape [E, B, C]: average the loss over the E members
                loss = 0.0
                for e in range(num_members):
                    loss += loss_function(out[e], y)
                loss = loss / num_members
            else:
                # fallback to standard loss computation
                loss = loss_function(out, y)
            loss.backward()
            total_loss += loss.item()
            optimizer.step()
        avg_loss = total_loss / len(train_loader)
        t1 = time.perf_counter()
        print(f"Epoch {epoch + 1}/{epochs} trained in {t1 - t0} seconds.")
        print(f"> Loss: {avg_loss}")
    return model

def train_ensemble(
    ensemble: nn.ModuleList,
    train_loader: DataLoader,
    epochs: int = 10,
    lr: float = 1e-3,
) -> nn.ModuleList:
    model = nn.ModuleList()

    for i, member in enumerate(ensemble):
        print(f"\nTraining ensemble member {i + 1}/{len(ensemble)}")
        member_i = member.to(device)
        optimizer = optim.Adam(member_i.parameters(), lr=lr)
        train_model(
            member_i,
            optimizer=optimizer,
            loss_function=nn.CrossEntropyLoss(),
            train_loader=train_loader,
            epochs=epochs,
        )
        model.append(member_i)
    return model


def train_batchensemble(
    base_model: nn.Module,
    num_members: int,
    train_loader: DataLoader,
    epochs: int = 10,
    lr: float = 1e-3,
) -> nn.Module:
    model = base_model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    model = train_model(
        model,
        optimizer=optimizer,
        loss_function=nn.CrossEntropyLoss(),
        train_loader=train_loader,
        epochs=epochs,
        num_members=num_members,
    )

    return model

epochs = 1
lr = 1e-3

t0_batch_ensemble = time.perf_counter()
trained_batch_ensemble = train_batchensemble(
    base_model=batch_ensemble_mlp,
    num_members=num_members,
    train_loader=train_loader,
    epochs=epochs,
    lr=lr,
)
t1_batch_ensemble = time.perf_counter()
print(f"\nTrained BatchEnsemble model of size {num_members} in {t1_batch_ensemble - t0_batch_ensemble:.2f}s")

t0_ensemble = time.perf_counter()
trained_ensemble = train_ensemble(
    ensemble=ensemble_mlp,
    train_loader=train_loader,
    epochs=epochs,
    lr=lr,
)
t1_ensemble = time.perf_counter()
print(f"\nTrained classical ensemble of size {num_members} in {t1_ensemble - t0_ensemble:.2f}s")
D) Evaluation and Comparison

To compare accuracy, we define an Evaluator that computes the mean per-member accuracy on the validation set and collects every member's predictions.

class Evaluator:
    def __init__(self, data_loader: torch.utils.data.DataLoader, device: str) -> None:
        """Initialize the Evaluator with a data loader and device.

        Args:
            data_loader (torch.utils.data.DataLoader): DataLoader for evaluation data.
            device (str): Device to run the evaluation on ('cpu' or 'cuda').
        """
        self.data_loader = data_loader
        self.device = device

    def _setup(self) -> None:
        self.correct = 0
        self.total = 0
        self.member_predictions = []

    def evaluate_batchensemble(self, model: nn.Module, num_members: int) -> tuple[float, torch.Tensor]:
        """Evaluate a BatchEnsemble model."""
        self._setup()
        model.to(self.device)
        model.eval()

        with torch.no_grad():
            for xb, yb in self.data_loader:
                x = xb.to(self.device).float()
                y = yb.to(self.device).long()

                out = model(x)  # [E, B, out_dim]
                preds = torch.argmax(out, dim=2)  # [E, B]

                self.correct += (preds == y.unsqueeze(0)).sum().item()
                self.total += y.size(0) * num_members
                self.member_predictions.append(preds.cpu())

        accuracy = self.correct / self.total
        all_member_preds = torch.cat(self.member_predictions, dim=1)

        return accuracy, all_member_preds

    def evaluate_classical_ensemble(self, models: nn.ModuleList) -> tuple[float, torch.Tensor]:
        """Evaluate a classical ensemble of models."""
        self._setup()
        for m in models:
            m.to(self.device)
            m.eval()

        with torch.no_grad():
            for xb, yb in self.data_loader:
                x = xb.to(self.device).float()
                y = yb.to(self.device).long()

                batch_member_preds = []
                for m in models:
                    out = m(x)  # [B, out_dim]
                    preds = torch.argmax(out, dim=1)  # [B]
                    batch_member_preds.append(preds.cpu().unsqueeze(0))  # [1, B]

                batch_member_preds = torch.cat(batch_member_preds, dim=0)  # [E, B]
                self.correct += (batch_member_preds == y.unsqueeze(0).cpu()).sum().item()
                self.total += y.size(0) * len(models)
                self.member_predictions.append(batch_member_preds)

        accuracy = self.correct / self.total
        all_member_preds = torch.cat(self.member_predictions, dim=1)
        return accuracy, all_member_preds

# Evaluate BatchEnsemble
evaluator = Evaluator(val_loader, device)
be_acc, be_member_preds = evaluator.evaluate_batchensemble(trained_batch_ensemble, num_members)
print(f"BatchEnsemble Accuracy: {be_acc:.4f}")

# Evaluate classical ensemble
for m in trained_ensemble:
    m.to(device)
ce_acc, ce_member_preds = evaluator.evaluate_classical_ensemble(trained_ensemble)
print(f"Classical Ensemble Accuracy: {ce_acc:.4f}")