Custom Loss Functions

This notebook provides a practical introduction to the specialized loss functions in probly. While standard losses like nn.CrossEntropyLoss are sufficient for deterministic models, probabilistic models often require custom loss functions to handle uncertainty.

We will cover three key types of custom losses:

  • Negative Log-Likelihood (NLL) Losses: Adaptations for probabilistic outputs.

  • Evidential Losses: Specialized functions for models that learn “evidence.”

  • Calibration-Aware Losses: Losses that directly optimize for model calibration.


1. Negative Log-Likelihood (NLL) Losses

NLL losses are a foundational concept in training probabilistic models. Instead of just penalizing wrong predictions, they evaluate how well the entire predicted distribution explains the true target.
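
To make this concrete, the following minimal sketch compares three predicted class distributions for the same true label. The helper name categorical_nll and the probability values are made up for illustration and are not part of probly. The first two predictions both pick the correct class, so accuracy cannot tell them apart, but the NLL can.

```python
import math


def categorical_nll(probs: list[float], target: int) -> float:
    """Negative log-likelihood of the true class under a predicted distribution."""
    return -math.log(probs[target])


target = 2  # index of the true class

confident_right = [0.05, 0.05, 0.90]
uncertain = [0.30, 0.30, 0.40]
confident_wrong = [0.90, 0.05, 0.05]

examples = [
    ("confident, right", confident_right),
    ("uncertain", uncertain),
    ("confident, wrong", confident_wrong),
]
for name, probs in examples:
    print(f"{name:>17}: NLL = {categorical_nll(probs, target):.3f}")
```

The NLL grows as less probability mass is placed on the true class, and it penalizes a confident wrong prediction far more than an uncertain one.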

Example: The ELBO Loss for Bayesian Neural Networks

A Bayesian Neural Network (BNN) requires a unique loss function that balances two goals:

  1. Fit the data: Make accurate predictions (similar to a standard loss).

  2. Stay simple: Keep the weight distributions close to a simple prior distribution.

The Evidence Lower Bound (ELBO) loss achieves this.
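
Written as a formula, the two goals correspond to the two terms of the objective. The notation here is an assumption made for illustration: q_φ(w) is the learned weight distribution, p(w) is the prior, and λ is the weight called kl_penalty in the code.

```latex
\mathcal{L}(\phi)
  = \underbrace{\mathbb{E}_{q_\phi(w)}\bigl[-\log p(y \mid x, w)\bigr]}_{\text{fit the data}}
  + \lambda \, \underbrace{\mathrm{KL}\bigl(q_\phi(w) \,\|\, p(w)\bigr)}_{\text{stay simple}}
```

Strictly speaking, the negative ELBO uses λ = 1 with the data term summed over all N training points. When the data term is a mean-reduced cross-entropy, as is typical in practice, a λ of roughly 1/N plays the same role, which is one reason small values such as 1e-5 are common.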

import torch
from torch import nn
import torch.nn.functional as F

from probly.train.bayesian.torch import collect_kl_divergence
from probly.transformation import bayesian


class ELBOLoss(nn.Module):
    """Evidence Lower Bound (ELBO) Loss."""

    def __init__(self, kl_penalty: float = 1e-5) -> None:
        """Initialize the loss.

        Args:
            kl_penalty: The penalty weight for the KL divergence term.

        """
        super().__init__()
        self.kl_penalty = kl_penalty

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor, kl: torch.Tensor) -> torch.Tensor:
        """Compute the ELBO loss.

        Args:
            inputs: The model outputs (unnormalized class logits).
            targets: The target tensor.
            kl: The KL divergence tensor.

        Returns:
            The calculated loss.

        """
        # 1. Standard Cross-Entropy
        cross_entropy_loss = F.cross_entropy(inputs, targets)

        # 2. KL Divergence Regularizer
        kl_divergence = self.kl_penalty * kl

        return cross_entropy_loss + kl_divergence

The ELBOLoss combines a standard cross-entropy loss with a KL divergence term, which penalizes the model for weight distributions that are too complex or too far from the prior. The kl_penalty weight controls the trade-off between the two terms. For more information on how Bayesian models work, see the Bayesian Transformation tutorial.

# Example: Training with ELBO Loss
import torch
from torch import nn

# Create a simple model and transform it to Bayesian
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 3))
bnn_model = bayesian(model)

# Create the ELBO loss
criterion = ELBOLoss(kl_penalty=1e-5)

# Dummy data (8 samples, 10 features, 3 classes)
inputs = torch.randn(8, 10)
targets = torch.randint(0, 3, (8,))

# Forward pass
outputs = bnn_model(inputs)

# Collect KL divergence from all Bayesian layers
kl = collect_kl_divergence(bnn_model)

# Compute loss
loss = criterion(outputs, targets, kl)

print(f"ELBO Loss: {loss.item():.4f}")
print(f"  - Cross-Entropy component: {nn.functional.cross_entropy(outputs, targets).item():.4f}")
print(f"  - KL Divergence component: {(criterion.kl_penalty * kl).item():.6f}")
print(f"\nTotal KL from all layers: {kl.item():.4f}")

2. Evidential Losses

Evidential Deep Learning models do not output probabilities directly. Instead, they output evidence for each class, which requires specialized loss functions.

Example: Evidential Losses for Classification and Regression

The probly library provides custom loss functions for evidential learning, based on the original research papers:

  • EvidentialLogLoss (Classification): Adapts the standard log loss to work with evidence scores (alpha) instead of probabilities.

  • EvidentialNIGNLLLoss (Regression): A more complex negative log-likelihood (NLL) loss that handles the four parameters of an evidential regression model: gamma, nu, alpha, and beta.

Training an Evidential Model

The training loop for an evidential model typically combines:

  1. An evidential NLL loss (classification or regression), and

  2. A regularization term that encourages the model to remain uncertain on out-of-distribution data.

The total loss is a weighted sum of these two components.
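
A minimal sketch of the evidential training objective described in this section (an evidential NLL plus a weighted regularizer), for the classification case. It is written in plain PyTorch and is not probly's implementation: the function names, the softplus evidence mapping, the choice of a KL-to-uniform-Dirichlet regularizer, and the reg_weight value are all assumptions made for illustration.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Dirichlet, kl_divergence


def evidential_log_loss(alpha: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Log loss on Dirichlet parameters: sum_k y_k * (log S - log alpha_k)."""
    strength = alpha.sum(dim=-1, keepdim=True)  # S = sum_k alpha_k
    one_hot = F.one_hot(targets, num_classes=alpha.shape[-1]).to(alpha.dtype)
    return (one_hot * (strength.log() - alpha.log())).sum(dim=-1).mean()


def uniform_kl_regularizer(alpha: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """KL divergence to the uniform Dirichlet after removing true-class evidence."""
    one_hot = F.one_hot(targets, num_classes=alpha.shape[-1]).to(alpha.dtype)
    # Reset the true-class parameter to 1 so only evidence for wrong classes is penalized.
    alpha_tilde = one_hot + (1.0 - one_hot) * alpha
    uniform = Dirichlet(torch.ones_like(alpha_tilde))
    return kl_divergence(Dirichlet(alpha_tilde), uniform).mean()


torch.manual_seed(0)
logits = torch.randn(8, 3)            # stand-in for raw network outputs
alpha = F.softplus(logits) + 1.0      # evidence >= 0, hence alpha >= 1
targets = torch.randint(0, 3, (8,))

reg_weight = 0.1                      # hypothetical weighting factor
nll = evidential_log_loss(alpha, targets)
reg = uniform_kl_regularizer(alpha, targets)
total = nll + reg_weight * reg

print(f"NLL: {nll.item():.4f}  regularizer: {reg.item():.4f}  total: {total.item():.4f}")
```

In practice the regularization weight is often increased gradually over the first epochs so that the model can collect evidence before being pushed back toward uncertainty.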

For full implementations, see the Evidential Classification and Evidential Regression tutorials.

3. Calibration-Aware Losses

Sometimes, the most effective way to achieve good calibration is to include a calibration objective directly in the loss function. This forces the model to optimize calibration as part of training.

Example: Label Relaxation

Label Relaxation is a simple but effective technique for reducing over-confidence and improving model calibration. Instead of using hard one-hot encoded labels (e.g., [0, 0, 1]), the labels are softened:

  • The true class is assigned a slightly lower value (e.g., 0.9).

  • The remaining probability mass (e.g., 0.1) is distributed across the other classes.

This discourages the model from producing extreme, over-confident predictions.
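
The effect of softened targets can be sketched in a few lines of plain Python. The helper names are hypothetical, and the uniform split of the remaining mass is an assumption (it corresponds to label-smoothing-style targets); LabelRelaxationLoss may distribute that mass differently.

```python
import math


def soften_label(target: int, num_classes: int, alpha: float) -> list[float]:
    """True class gets 1 - alpha; the other classes share alpha equally."""
    off_value = alpha / (num_classes - 1)
    return [1.0 - alpha if k == target else off_value for k in range(num_classes)]


def cross_entropy(target_probs: list[float], predicted_probs: list[float]) -> float:
    """Cross-entropy between a (possibly soft) target and a predicted distribution."""
    return -sum(t * math.log(p) for t, p in zip(target_probs, predicted_probs) if t > 0)


hard = soften_label(target=2, num_classes=3, alpha=0.0)  # one-hot target
soft = soften_label(target=2, num_classes=3, alpha=0.1)  # roughly [0.05, 0.05, 0.9]

moderate = [0.05, 0.05, 0.90]        # confident, but not extreme
extreme = [0.0005, 0.0005, 0.9990]   # almost all mass on one class

for name, pred in [("moderate", moderate), ("extreme", extreme)]:
    print(
        f"{name:>8}: hard CE = {cross_entropy(hard, pred):.4f}, "
        f"soft CE = {cross_entropy(soft, pred):.4f}"
    )
```

With hard labels the extreme prediction achieves the lower loss, so training keeps pushing toward it. With softened labels the moderate prediction achieves the lower loss, which removes that incentive.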

The probly library implements this approach as LabelRelaxationLoss. It can be used as a drop-in replacement for standard losses like nn.CrossEntropyLoss, which makes it easy to integrate into existing training pipelines.

# Example: Training with Label Relaxation
import torch
from torch import nn

from probly.train.calibration.torch import LabelRelaxationLoss

# Create a simple classifier
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 3))

# Use Label Relaxation instead of CrossEntropyLoss
# alpha=0.1 means: true class gets 0.9, other classes share 0.1
criterion = LabelRelaxationLoss(alpha=0.1)

# Standard CrossEntropyLoss for comparison
standard_criterion = nn.CrossEntropyLoss()

# Dummy data
inputs = torch.randn(8, 10)
targets = torch.randint(0, 3, (8,))

# Forward pass
outputs = model(inputs)

# Compare losses
relaxed_loss = criterion(outputs, targets)
standard_loss = standard_criterion(outputs, targets)

print(f"Standard CrossEntropy Loss: {standard_loss.item():.4f}")
print(f"Label Relaxation Loss:      {relaxed_loss.item():.4f}")

By optimizing this “softer” objective, the model is no longer rewarded for extreme confidence and tends to produce better-calibrated probability estimates.

For a full implementation, see the Label Relaxation Calibration tutorial.