Source code for probly.models.evidential.torch

"""Models for evidential deep learning using PyTorch."""

from __future__ import annotations

import torch
from torch import nn

import probly.layers.evidential.torch as t


class EDLModel(nn.Module):
    """Simple model for evidential deep learning (EDL) classification.

    Combines an encoder with an evidential prediction head to model
    classification uncertainty following Sensoy et al. (2018).

    Reference:
        Sensoy et al., "Evidential Deep Learning to Quantify Classification
        Uncertainty", NeurIPS 2018. https://arxiv.org/abs/1806.01768
    """

    def __init__(
        self,
        encoder: nn.Module,
        head: nn.Module | None = None,
        latent_dim: int | None = None,
        num_classes: int = 10,
    ) -> None:
        """Initialize the EDLModel for evidential classification.

        Args:
            encoder: Encoder module mapping inputs to latent space.
            head: Head module for evidential output. If None, a ``t.EDLHead``
                is built from ``latent_dim`` and ``num_classes``.
            latent_dim: Dimension of the latent space. Only used to build the
                default head. If None, it is read from ``encoder.latent_dim``.
            num_classes: Number of output classes. Only used to build the
                default head.

        Raises:
            ValueError: If the default head has to be built and no integer
                ``latent_dim`` was given or found on the encoder.
        """
        super().__init__()

        if head is None:
            # The latent size is only consumed by the default head, so a
            # caller-supplied head must not be rejected for a missing value.
            if latent_dim is None:
                latent_dim = getattr(encoder, "latent_dim", None)
            if not isinstance(latent_dim, int):
                msg = "latent_dim must be specified if encoder has no attribute 'latent_dim'"
                raise ValueError(msg)
            head = t.EDLHead(latent_dim=latent_dim, num_classes=num_classes)

        self.encoder = encoder
        self.head = head

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through encoder and head.

        Args:
            x: Input tensor compatible with the encoder.

        Returns:
            Output tensor from the head module.
        """
        features = self.encoder(x)
        return self.head(features)
class PrNetModel(nn.Module):
    """Encoder plus Dirichlet Prior Network head for evidential classification.

    The encoder produces latent features and the head turns them into the
    model output, following the Prior Network approach of Malinin and Gales
    (2018) for modelling distributional uncertainty.

    Reference:
        Malinin and Gales, "Predictive Uncertainty Estimation via Prior
        Networks", NeurIPS 2018. https://arxiv.org/abs/1802.10501
    """

    def __init__(
        self,
        encoder: nn.Module,
        head: nn.Module | None = None,
        latent_dim: int = 256,
        num_classes: int = 10,
    ) -> None:
        """Build the model from an encoder and an optional custom head.

        Args:
            encoder: Module mapping inputs to latent features.
            head: Module applied to the latent features. When None, a
                ``t.PrNetHead`` is created from ``latent_dim`` and
                ``num_classes``.
            latent_dim: Latent feature size handed to the default head.
            num_classes: Number of classes handed to the default head.
        """
        super().__init__()

        resolved_head = (
            t.PrNetHead(latent_dim=latent_dim, num_classes=num_classes)
            if head is None
            else head
        )

        # Encoder is registered before the head, so child order stays stable.
        self.encoder = encoder
        self.head = resolved_head

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and return the head's output for the latent features."""
        return self.head(self.encoder(x))
class PostNetModel(nn.Module):
    """Posterior Network model for evidential classification.

    Combines an encoder with class-conditional normalizing flows to model
    uncertainty-aware predictions following Charpentier et al. (2020).

    Reference:
        Charpentier et al., "Posterior Network: Uncertainty Estimation without
        OOD Samples via Density-Based Pseudo-Counts", NeurIPS 2020.
        https://arxiv.org/abs/2006.09239
    """

    # Registered as a buffer in ``__init__``; declared here for type checkers.
    class_counts: torch.Tensor

    def __init__(
        self,
        encoder: nn.Module,
        latent_dim: int = 6,
        num_classes: int = 10,
        flow: t.BatchedRadialFlowDensity | None = None,
        class_counts: torch.Tensor | None = None,
        flow_length: int = 6,
    ) -> None:
        """Initialize a Posterior Network model.

        Args:
            encoder: Encoder mapping inputs to a latent space.
            latent_dim: Dimensionality of the latent space.
            num_classes: Number of output classes.
            flow: Class-conditional normalizing flow exposing ``log_prob``.
                If None, a ``t.BatchedRadialFlowDensity`` is created.
            class_counts: Empirical class counts used to scale the densities.
                Defaults to a vector of ones of length ``num_classes``.
                Non-tensor values are converted with ``torch.as_tensor``.
            flow_length: Flow length passed to the default flow. Ignored when
                ``flow`` is given.
        """
        super().__init__()
        self.encoder = encoder
        self.num_classes = num_classes
        self.latent_dim = latent_dim

        if flow is None:
            flow = t.BatchedRadialFlowDensity(
                num_classes=num_classes,
                latent_dim=latent_dim,
                flow_length=flow_length,
            )
        self.flow = flow

        if class_counts is None:
            class_counts = torch.ones(num_classes)
        # as_tensor returns an existing tensor unchanged and lets plain
        # sequences be used as well; register_buffer only accepts tensors.
        self.register_buffer("class_counts", torch.as_tensor(class_counts))

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward pass of the Posterior Network.

        Args:
            x: Input tensor compatible with the encoder.

        Returns:
            alpha: Dirichlet concentration parameters, computed as
                ``exp(log_prob) * class_counts + 1``. Expected shape is
                (batch_size, num_classes), assuming the flow returns one
                log-density per class.
            p_mean: Predictive mean of the Dirichlet distribution, i.e.
                ``alpha`` normalized over dimension 1.
            z: Latent representation of the input (the encoder output).
        """
        features = self.encoder(x)
        log_dens = self.flow.log_prob(features)
        dens = log_dens.exp()

        # Scale per-class densities by the class counts (broadcast over batch).
        beta = dens * self.class_counts.unsqueeze(0)
        # Adding one gives a flat Dirichlet prior when the density vanishes.
        alpha = beta + 1.0

        alpha0 = alpha.sum(dim=1, keepdim=True)
        p_mean = alpha / alpha0
        return alpha, p_mean, features
class NatPNModel(nn.Module):
    """Natural Posterior Network for evidential deep learning with normalizing flows.

    Combines encoder, normalizing flow density, and head for uncertainty
    quantification following Charpentier et al., 2022. Users can provide
    custom encoders for different data modalities.

    Reference:
        Charpentier et al., "Natural Posterior Network", ICLR 2022.
        https://arxiv.org/abs/2105.04471
    """

    def __init__(
        self,
        encoder: nn.Module,
        head: nn.Module | None = None,
        latent_dim: int | None = None,
        flow_length: int = 4,
        certainty_budget: float | None = 2.0,
        num_classes: int = 10,
    ) -> None:
        """Initialize the NatPN model.

        Args:
            encoder: Encoder mapping inputs to a latent space.
            head: Prediction head operating on latent features. If None, a
                ``t.NatPNClassHead`` is built from ``latent_dim`` and
                ``num_classes``.
            latent_dim: Dimensionality of the latent space. If None, read
                from ``encoder.latent_dim``. Always required here because the
                flow density is built from it.
            flow_length: Number of radial flow layers.
            certainty_budget: Budget controlling certainty calibration. If
                None, ``float(latent_dim)`` is used.
            num_classes: Number of classes for the default head. Ignored when
                ``head`` is given.

        Raises:
            ValueError: If no integer ``latent_dim`` was given or found on
                the encoder.
        """
        super().__init__()

        if latent_dim is None:
            latent_dim = getattr(encoder, "latent_dim", None)
        if not isinstance(latent_dim, int):
            msg = "latent_dim must be specified if encoder has no attribute 'latent_dim'"
            raise ValueError(msg)

        if head is None:
            head = t.NatPNClassHead(latent_dim=latent_dim, num_classes=num_classes)

        self.encoder = encoder
        self.head = head
        self.flow = t.RadialFlowDensity(
            latent_dim=latent_dim,
            flow_length=flow_length,
        )

        # None means "derive the budget from the latent size".
        if certainty_budget is None:
            certainty_budget = float(latent_dim)
        self.certainty_budget = certainty_budget

    def freeze_encoder(self) -> None:
        """Freeze encoder weights (for transfer learning)."""
        for param in self.encoder.parameters():
            param.requires_grad = False

    def unfreeze_encoder(self) -> None:
        """Unfreeze encoder weights (for fine-tuning)."""
        for param in self.encoder.parameters():
            param.requires_grad = True

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """Forward pass through encoder, flow, and head.

        Args:
            x: Input tensor compatible with the encoder.

        Returns:
            Whatever the head returns when called with the latent features,
            the flow log-density and the certainty budget; annotated as a
            dictionary of tensors.
        """
        features = self.encoder(x)  # [B, latent_dim]
        log_pz = self.flow.log_prob(features)  # [B]
        return self.head(
            features=features,
            log_pz=log_pz,
            certainty_budget=self.certainty_budget,
        )
class IRDModel(nn.Module):
    """Encoder plus Dirichlet head for evidential classification.

    Latent features from the encoder are passed to a Dirichlet head that
    yields concentration parameters for uncertainty-aware classification,
    following Tsiligkaridis (2019).

    Reference:
        Tsiligkaridis, "Information Robust Dirichlet Networks", 2019.
        https://arxiv.org/abs/1910.04819
    """

    def __init__(
        self,
        encoder: nn.Module,
        head: nn.Module | None = None,
        latent_dim: int = 128,
        num_classes: int = 10,
    ) -> None:
        """Build the model from an encoder and an optional custom head.

        Args:
            encoder: Module mapping inputs to latent features.
            head: Module mapping latent features to alpha parameters. When
                None, a ``t.IRDHead`` is created from ``latent_dim`` and
                ``num_classes``.
            latent_dim: Latent feature size handed to the default head.
            num_classes: Number of classes handed to the default head.
        """
        super().__init__()

        resolved_head = (
            t.IRDHead(latent_dim=latent_dim, num_classes=num_classes)
            if head is None
            else head
        )

        # Encoder is registered before the head, so child order stays stable.
        self.encoder = encoder
        self.head = resolved_head

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and map the latent features to alpha parameters.

        Args:
            x: Input tensor compatible with the encoder.

        Returns:
            The head's output for the encoded input.
        """
        return self.head(self.encoder(x))
class EvidentialRegressionModel(nn.Module):
    """Full evidential regression model combining encoder and evidential head.

    Implements evidential regression for uncertainty-aware prediction,
    inspired by Deep Evidential Regression and Regression Prior Networks.

    References:
        Amini et al., "Deep Evidential Regression", NeurIPS 2020.
        https://arxiv.org/abs/1910.02600

        Malinin et al., "Regression Prior Networks", NeurIPS 2020.
        https://arxiv.org/abs/2006.11590
    """

    def __init__(
        self,
        encoder: nn.Module,
        head: nn.Module | None = None,
        latent_dim: int | None = None,
    ) -> None:
        """Initialize the EvidentialRegressionModel.

        Args:
            encoder: Encoder module mapping inputs to a latent space.
            head: Evidential regression head producing distribution
                parameters. If None, defaults to ``t.RegressionHead``.
            latent_dim: Dimensionality of the latent space. Only used to
                build the default head. If None, inferred from
                ``encoder.latent_dim``.

        Raises:
            ValueError: If the default head has to be built and no integer
                ``latent_dim`` was given or found on the encoder.
        """
        super().__init__()

        if head is None:
            # The latent size is only consumed by the default head, so a
            # caller-supplied head must not be rejected for a missing value.
            if latent_dim is None:
                latent_dim = getattr(encoder, "latent_dim", None)
            if not isinstance(latent_dim, int):
                msg = "latent_dim must be specified if encoder has no attribute 'latent_dim'"
                raise ValueError(msg)
            head = t.RegressionHead(latent_dim=latent_dim)

        self.encoder = encoder
        self.head = head

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward pass through encoder and head.

        Args:
            x: Input tensor compatible with the encoder.

        Returns:
            The head's output for the encoded input, annotated as a tuple of
            four tensors.
        """
        features = self.encoder(x)
        return self.head(features)