Utility Functions

Introduction

This notebook provides a practical introduction to the core utility functions in probly. These helpers are essential building blocks for training probabilistic models and quantifying uncertainty.

We will focus on two main categories:

  • Model traversal functions, which inspect a model’s architecture

  • Uncertainty quantification functions, which compute meaningful uncertainty scores from model predictions


Key Utility Functions in probly

1. collect_kl_divergence (for BNNs)

What it does: Automatically traverses a Bayesian Neural Network and sums the KL divergence from each Bayesian layer.

Why it’s useful: The summed KL term is the complexity (regularization) part of the evidence lower bound (ELBO) loss used to train BNNs with variational inference.
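
The summed KL typically enters training as the complexity term of the (negative) ELBO. Below is a minimal sketch of that combination; `elbo_loss`, the `beta` weight, and the per-batch KL scaling are common conventions used for illustration here, not probly API.

```python
def elbo_loss(nll, kl_divergence, num_batches, beta=1.0):
    """Negative ELBO for one mini-batch (illustrative sketch, not probly API).

    nll: negative log-likelihood of the batch (data-fit term)
    kl_divergence: summed KL over all Bayesian layers,
        e.g. the value returned by collect_kl_divergence(bnn_model)
    num_batches: batches per epoch; dividing the KL by this spreads
        the complexity penalty evenly across an epoch
    beta: optional KL weight (beta < 1 anneals the penalty)
    """
    return nll + beta * kl_divergence / num_batches

# Toy numbers: batch NLL of 0.9, total KL of 12.0, 10 batches per epoch
print(elbo_loss(nll=0.9, kl_divergence=12.0, num_batches=10))
```

Minimizing this quantity trades off fitting the data (the NLL term) against keeping the weight posteriors close to their priors (the KL term).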


2. total_entropy, conditional_entropy, mutual_information

What they do: These functions take a set of predictions (for example, from an ensemble) and decompose predictive uncertainty.

Why they’re useful: They allow you to separately measure:

  • Aleatoric uncertainty (inherent randomness in the data)

  • Epistemic uncertainty (uncertainty due to limited model knowledge)
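
The decomposition behind these three functions is additive: total entropy = conditional entropy + mutual information. A plain-NumPy sketch of that identity follows; it illustrates the math, not probly's exact implementation.

```python
import numpy as np

def entropy(p, axis=-1):
    # Shannon entropy in nats; clipping avoids log(0)
    p = np.clip(p, 1e-12, 1.0)
    return -np.sum(p * np.log(p), axis=axis)

# Shape: (num_instances, num_samples, num_classes)
preds = np.array([
    [[0.8, 0.1, 0.1], [0.75, 0.15, 0.1], [0.85, 0.1, 0.05]],  # agreement
    [[0.7, 0.2, 0.1], [0.2, 0.7, 0.1], [0.1, 0.2, 0.7]],      # disagreement
])

total = entropy(preds.mean(axis=1))        # H of the mean prediction
conditional = entropy(preds).mean(axis=1)  # mean H per member (aleatoric)
mutual = total - conditional               # disagreement term (epistemic)

print(total, conditional, mutual)
```

Mutual information is near zero when the members agree (instance 1) and large when they disagree (instance 2), which is exactly why it serves as an epistemic uncertainty measure.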


3. evidential_uncertainty (for Evidential Models)

What it does: Computes an uncertainty score directly from the evidence vector produced by an evidential model.

Why it’s useful: Unlike ensembles or Monte Carlo sampling, it requires only a single forward pass, making it a fast way to flag predictions the model is uncertain about.
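
For intuition, one widely used evidential uncertainty score is the Dirichlet "vacuity" u = K / S, where K is the number of classes and S is the total concentration. The sketch below illustrates that formula with plain NumPy; probly's evidential_uncertainty may compute a different score.

```python
import numpy as np

def dirichlet_vacuity(alpha):
    # Vacuity u = K / S: equals 1.0 with no evidence (alpha of all ones)
    # and approaches 0 as total evidence S = sum(alpha) grows.
    alpha = np.asarray(alpha, dtype=float)
    num_classes = alpha.shape[-1]
    return num_classes / alpha.sum(axis=-1)

print(dirichlet_vacuity([[100.0, 2.0, 3.0]]))  # low: roughly 0.0286
print(dirichlet_vacuity([[1.0, 1.0, 1.0]]))    # maximal: 1.0
```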

# Example 1: collect_kl_divergence for Bayesian Neural Networks
from probly.transformation import bayesian
from probly.train.bayesian.torch import collect_kl_divergence
import torch
from torch import nn

# Create and transform a model to Bayesian
model = nn.Sequential(
    nn.Linear(10, 32),
    nn.ReLU(),
    nn.Linear(32, 3)
)
bnn_model = bayesian(model)

# Create dummy input
inputs = torch.randn(4, 10)

# Forward pass (this samples weights from distributions)
outputs = bnn_model(inputs)

# Collect KL divergence from all Bayesian layers
kl = collect_kl_divergence(bnn_model)

print(f"Total KL Divergence: {kl.item():.4f}")

# Example 2: Decomposing uncertainty with entropy functions
from probly.quantification.classification import total_entropy, conditional_entropy, mutual_information
import numpy as np

# Simulated predictions from a 3-member ensemble for 2 instances, 3 classes
# Shape: (num_instances, num_samples, num_classes)
ensemble_predictions = np.array([
    # Instance 1: High agreement (low epistemic uncertainty)
    [[0.8, 0.1, 0.1],
     [0.75, 0.15, 0.1],
     [0.85, 0.1, 0.05]],
    
    # Instance 2: High disagreement (high epistemic uncertainty)
    [[0.7, 0.2, 0.1],
     [0.2, 0.7, 0.1],
     [0.1, 0.2, 0.7]]
])

# Compute uncertainty metrics
total_ent = total_entropy(ensemble_predictions)
cond_ent = conditional_entropy(ensemble_predictions)  # Aleatoric uncertainty
mutual_info = mutual_information(ensemble_predictions)  # Epistemic uncertainty

print("Instance 1 (models agree):")
print(f"  Total Entropy: {total_ent[0]:.4f}")
print(f"  Aleatoric Uncertainty: {cond_ent[0]:.4f}")
print(f"  Epistemic Uncertainty: {mutual_info[0]:.4f}")

print("\nInstance 2 (models disagree):")
print(f"  Total Entropy: {total_ent[1]:.4f}")
print(f"  Aleatoric Uncertainty: {cond_ent[1]:.4f}")
print(f"  Epistemic Uncertainty: {mutual_info[1]:.4f}")

# Example 3: evidential_uncertainty for Evidential Models
from probly.quantification.classification import evidential_uncertainty
import numpy as np

# Simulated Dirichlet parameters (alpha) from an evidential model.
# In typical evidential models alpha = evidence + 1, so alpha = [1, 1, 1]
# means no evidence at all. High total evidence = confident; low = uncertain.

# Confident prediction: lots of evidence for class 0
confident_evidence = np.array([[100.0, 2.0, 3.0]])

# Uncertain prediction: little evidence for any class
uncertain_evidence = np.array([[1.0, 1.0, 1.0]])

# Compute uncertainty scores
confident_uncertainty = evidential_uncertainty(confident_evidence)
uncertain_uncertainty = evidential_uncertainty(uncertain_evidence)

print(f"Confident prediction evidence: {confident_evidence[0]}")
print(f"  Uncertainty score: {confident_uncertainty[0]:.4f}")

print(f"\nUncertain prediction evidence: {uncertain_evidence[0]}")
print(f"  Uncertainty score: {uncertain_uncertainty[0]:.4f}")