Batch Ensemble Networks¶
In this notebook, we:
A) Explain the idea behind Batch Ensembles.
B) Create a small Multi-Layer-Perceptron (MLP).
C) Train the MLP as an Ensemble and BatchEnsemble network on CIFAR10.
D) Compare accuracy and speed.
A) Introduction: What are Batch Ensembles?¶
Batch Ensembles are a way to efficiently approximate an ensemble of neural networks. Traditional ensembles require training and storing multiple independent networks, which is memory and computation expensive.
Key ideas:
Use shared base weights (and biases) for all ensemble members.
Introduce rank-1 multiplicative factors for each member.
Much faster and memory-efficient than classic ensembles.
Mathematically the classic forward of
$$y = Wx + b$$
transforms to
$$y_i = \left(W \left(x \circ s_i\right)\right) \circ r_i + b$$
Where $r_i, s_i$ are the rank-1 vectors for ensemble member $i$, and $\circ$ denotes element-wise multiplication.
What does the transformation do?¶
The parameters¶
num_members: The number of ensemble members to create.
s_mean: The mean used to initialize the input modulation factor s.
s_std: The standard deviation used to initialize the input modulation factor s.
r_mean: The mean used to initialize the output modulation factor r.
r_std: The standard deviation used to initialize the output modulation factor r.
The layers¶
With these parameters we can transform:
Linear layer into BatchEnsembleLinear layer
Conv2d layer into BatchEnsembleConv2d layer
The transformation keeps the dimensions and base weights, while adding rank-1 factors:
s: scales input-dimension per member
r: scales output-dimension per member
The base weights (weight) and bias (bias) are shared across all members, keeping memory usage minimal. The differences between members arise solely from their individual scaling factors s and r.
B) Setup of the MLP¶
Standard Imports and Pytorch Setup
import time
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10
# Prefer the GPU when one is available; all later cells read this module-level name.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print("Using device:", device)
Using device: cpu
Import CIFAR10 Dataset
# Convert PIL images to tensors and map each channel from [0, 1] to [-1, 1].
normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        normalize,
        # add more transforms if desired
    ],
)

train_data = CIFAR10(root="./data", train=True, transform=transform, download=True)
val_data = CIFAR10(root="./data", train=False, transform=transform, download=True)

# Shuffle only the training split; evaluation order is kept fixed.
train_loader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=2)
val_loader = DataLoader(val_data, batch_size=32, shuffle=False, num_workers=2)

print(f"Train samples: {len(train_data)}, Val samples: {len(val_data)}")
Train samples: 50000, Val samples: 10000
The MLP Class
We create an MLP inheriting basic functionality from the nn.Module parent class. The MLP has two hidden layers utilizing the ReLU activation function.
class MLP(nn.Module):
    """A small multi-layer perceptron operating on flattened inputs."""

    def __init__(self, in_dim: int = 3072, hidden: int = 128, out_dim: int = 10) -> None:
        """Initialize the MLP model with two hidden layers.

        Args:
            in_dim (int): Dimension of the input features. Default is 3072 (32x32x3 for CIFAR-10).
            hidden (int): Number of neurons in the hidden layers. Default is 128.
            out_dim (int): Dimension of the output features. Default is 10 (number of classes in CIFAR-10).
        """
        super().__init__()
        layers = [
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the MLP model.

        Flattens the input to (batch_size, -1) before running it through the network.

        Args:
            x (torch.Tensor): Input tensor with leading batch dimension.

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, out_dim).
        """
        flat = x.view(x.size(0), -1)
        return self.net(flat)
The Ensemble MLP
from probly.transformation.ensemble import ensemble

# Hyperparameters shared by both ensemble variants below.
in_dim = 3 * 32 * 32  # flattened 32x32 RGB image
hidden = 128
num_members = 5

# Classical ensemble: independent copies of the base MLP.
ensemble_mlp = ensemble(
    base=MLP(in_dim=in_dim, hidden=hidden, out_dim=10),
    num_members=num_members,
)
The BatchEnsemble MLP
from probly.transformation import batchensemble

# BatchEnsemble: one shared base MLP augmented with per-member rank-1 factors.
batch_ensemble_mlp = batchensemble(
    base=MLP(in_dim=in_dim, hidden=hidden, out_dim=10),
    num_members=num_members,
)
Let’s compare the different models now.
We start with a comparison of the base MLP and the BatchEnsemble MLP:
# Instantiate a fresh base MLP so its layout can be compared with the transformed model.
base_for_comparison = MLP(in_dim=in_dim, hidden=hidden, out_dim=10)
print(f"Base MLP:\n{base_for_comparison}\n")
print(f"BatchEnsemble MLP:\n{batch_ensemble_mlp}")
Base MLP:
MLP(
(net): Sequential(
(0): Linear(in_features=3072, out_features=128, bias=True)
(1): ReLU()
(2): Linear(in_features=128, out_features=128, bias=True)
(3): ReLU()
(4): Linear(in_features=128, out_features=10, bias=True)
)
)
BatchEnsemble MLP:
MLP(
(net): Sequential(
(0): BatchEnsembleLinear()
(1): ReLU()
(2): BatchEnsembleLinear()
(3): ReLU()
(4): BatchEnsembleLinear()
)
)
Then we compare the Ensemble MLP and the BatchEnsemble MLP:
# Show both ensemble variants side by side for a structural comparison.
print(f"Ensemble MLP:\n{ensemble_mlp}\n")
print(f"BatchEnsemble MLP:\n{batch_ensemble_mlp}")
Ensemble MLP:
ModuleList(
(0-4): 5 x MLP(
(net): Sequential(
(0): Linear(in_features=3072, out_features=128, bias=True)
(1): ReLU()
(2): Linear(in_features=128, out_features=128, bias=True)
(3): ReLU()
(4): Linear(in_features=128, out_features=10, bias=True)
)
)
)
BatchEnsemble MLP:
MLP(
(net): Sequential(
(0): BatchEnsembleLinear()
(1): ReLU()
(2): BatchEnsembleLinear()
(3): ReLU()
(4): BatchEnsembleLinear()
)
)
C) Training¶
Training Methods
While there is currently no training functionality implemented in probly, we define the training methods below.
Base Training Method
def train_model(
model: nn.Module,
optimizer: torch.optim.Optimizer,
loss_function: nn.CrossEntropyLoss,
train_loader: DataLoader,
epochs: int = 10,
num_members: int | None = None,
) -> nn.Module:
for epoch in range(epochs):
t0 = time.perf_counter()
total_loss = 0.0
model.train()
for xb, yb in train_loader:
x = xb.to(device).float()
y = yb.to(device).long()
optimizer.zero_grad()
out = model(x)
if isinstance(num_members, int) and num_members > 0:
loss = 0.0
for e in range(num_members):
loss += loss_function(out[e], y)
loss = loss / num_members
else:
# fallback to standard loss computation
loss = loss_function(out, y)
loss.backward()
total_loss += loss.item()
optimizer.step()
avg_loss = total_loss / len(train_loader)
t1 = time.perf_counter()
print(f"Epoch {epoch + 1}/{epochs} trained in {t1 - t0} seconds.")
print(f"> Loss: {avg_loss}")
return model
def train_ensemble(
    ensemble: nn.ModuleList,
    train_loader: DataLoader,
    epochs: int = 10,
    lr: float = 1e-3,
) -> nn.ModuleList:
    """Train every member of a classical ensemble independently.

    Args:
        ensemble (nn.ModuleList): The ensemble members to train (e.g. the result of
            ``probly``'s ``ensemble`` transformation, which is an ``nn.ModuleList``).
        train_loader (DataLoader): Loader yielding ``(inputs, targets)`` batches.
        epochs (int): Number of epochs to train each member. Default is 10.
        lr (float): Learning rate for each member's Adam optimizer. Default is 1e-3.

    Returns:
        nn.ModuleList: A new list containing the trained members.
    """
    model = nn.ModuleList()
    for i, member in enumerate(ensemble):
        print(f"\nTraining ensemble member {i + 1}/{len(ensemble)}")
        # Uses the notebook's module-level `device`.
        member_i = member.to(device)
        # Each member gets a fresh optimizer so the updates stay independent.
        optimizer = optim.Adam(member_i.parameters(), lr=lr)
        train_model(
            member_i,
            optimizer=optimizer,
            loss_function=nn.CrossEntropyLoss(),
            train_loader=train_loader,
            epochs=epochs,
        )
        model.append(member_i)
    return model
def train_batchensemble(
    base_cls: nn.Module,
    num_members: int,
    train_loader: DataLoader,
    epochs: int = 10,
    lr: float = 1e-3,
) -> nn.Module:
    """Train a BatchEnsemble model in a single pass over the data.

    Args:
        base_cls (nn.Module): The BatchEnsemble model instance to train. Its forward
            pass is expected to return logits shaped ``[num_members, batch, classes]``.
        num_members (int): Number of ensemble members encoded in the model output.
        train_loader (DataLoader): Loader yielding ``(inputs, targets)`` batches.
        epochs (int): Number of training epochs. Default is 10.
        lr (float): Learning rate for the Adam optimizer. Default is 1e-3.

    Returns:
        nn.Module: The trained model.
    """
    # Uses the notebook's module-level `device`.
    model = base_cls.to(device)
    # A single optimizer updates the shared weights and all rank-1 factors at once.
    optimizer = optim.Adam(model.parameters(), lr=lr)
    return train_model(
        model,
        optimizer=optimizer,
        loss_function=nn.CrossEntropyLoss(),
        train_loader=train_loader,
        epochs=epochs,
        num_members=num_members,
    )
epochs = 1
lr = 1e-3

# --- BatchEnsemble: one shared network, trained in a single pass ---
t0_batch_ensemble = time.perf_counter()
trained_batch_ensemble = train_batchensemble(
    base_cls=batch_ensemble_mlp,
    num_members=num_members,
    train_loader=train_loader,
    epochs=epochs,
    lr=lr,
)
t1_batch_ensemble = time.perf_counter()
print(f"\nTrained BatchEnsemble model of size {num_members} in {t1_batch_ensemble - t0_batch_ensemble:.2f}s")

# --- Classical ensemble: each member trained independently ---
t0_ensemble = time.perf_counter()
trained_ensemble = train_ensemble(
    ensemble=ensemble_mlp,
    train_loader=train_loader,
    epochs=epochs,
    lr=lr,
)
t1_ensemble = time.perf_counter()
print(f"\nTrained classical ensemble of size {num_members} in {t1_ensemble - t0_ensemble:.2f}s")
class Evaluator:
    """Computes per-member accuracy for BatchEnsemble and classical ensembles."""

    def __init__(self, data_loader: torch.utils.data.DataLoader, device: str) -> None:
        """Initialize the Evaluator with a data loader and device.

        Args:
            data_loader (torch.utils.data.DataLoader): DataLoader for evaluation data.
            device (str): Device to run the evaluation on ('cpu' or 'cuda').
        """
        self.data_loader = data_loader
        self.device = device

    def _setup(self) -> None:
        # Reset running statistics before each evaluation run.
        self.correct = 0
        self.total = 0
        self.member_predictions = []

    def evaluate_batchensemble(self, model: nn.Module, num_members: int) -> tuple[float, torch.Tensor]:
        """Evaluate a BatchEnsemble model.

        Returns the accuracy averaged over all members and the per-member
        predictions concatenated over the whole data loader, shaped [E, N].
        """
        self._setup()
        model.to(self.device)
        model.eval()
        with torch.no_grad():
            for inputs, targets in self.data_loader:
                inputs = inputs.to(self.device).float()
                targets = targets.to(self.device).long()
                logits = model(inputs)                 # [E, B, out_dim]
                member_preds = logits.argmax(dim=2)    # [E, B]
                # Count hits across all members at once; the denominator grows
                # by batch_size * num_members accordingly.
                self.correct += (member_preds == targets.unsqueeze(0)).sum().item()
                self.total += targets.size(0) * num_members
                self.member_predictions.append(member_preds.cpu())
        accuracy = self.correct / self.total
        return accuracy, torch.cat(self.member_predictions, dim=1)

    def evaluate_classical_ensemble(self, models: nn.ModuleList) -> tuple[float, torch.Tensor]:
        """Evaluate a classical ensemble of models.

        Returns the accuracy averaged over all members and the per-member
        predictions concatenated over the whole data loader, shaped [E, N].
        """
        self._setup()
        for member in models:
            member.to(self.device)
            member.eval()
        with torch.no_grad():
            for inputs, targets in self.data_loader:
                inputs = inputs.to(self.device).float()
                targets = targets.to(self.device).long()
                # Stack each member's argmax predictions into a single [E, B] tensor.
                per_member = torch.cat(
                    [member(inputs).argmax(dim=1).cpu().unsqueeze(0) for member in models],
                    dim=0,
                )
                self.correct += (per_member == targets.unsqueeze(0).cpu()).sum().item()
                self.total += targets.size(0) * len(models)
                self.member_predictions.append(per_member)
        accuracy = self.correct / self.total
        return accuracy, torch.cat(self.member_predictions, dim=1)
# Evaluate BatchEnsemble
evaluator = Evaluator(val_loader, device)
be_acc, be_member_preds = evaluator.evaluate_batchensemble(trained_batch_ensemble, num_members)
print(f"BatchEnsemble Accuracy: {be_acc:.4f}")

# Move every classical-ensemble member onto the evaluation device first.
for member in trained_ensemble:
    member.to(device)
ce_acc, ce_member_preds = evaluator.evaluate_classical_ensemble(trained_ensemble)
print(f"Classical Ensemble Accuracy: {ce_acc:.4f}")