Source code for probly.evaluation.ood

"""Unified OOD evaluation API for probly."""

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
import sklearn.metrics as sm
from sklearn.metrics import precision_recall_curve, roc_curve

from probly.visualization.ood import plot_histogram, plot_pr_curve, plot_roc_curve

if TYPE_CHECKING:
    from collections.abc import Callable

    from matplotlib.figure import Figure


def out_of_distribution_detection_auroc(in_distribution: np.ndarray, out_distribution: np.ndarray) -> float:
    """Perform out-of-distribution detection using prediction functionals from ID and OOD data.

    The scores can be epistemic uncertainty, as is common, but also e.g. softmax confidence.

    Args:
        in_distribution: in-distribution prediction functionals
        out_distribution: out-of-distribution prediction functionals

    Returns:
        auroc: float, area under the ROC curve
    """
    preds = np.concatenate((in_distribution, out_distribution))
    labels = np.concatenate((np.zeros(len(in_distribution)), np.ones(len(out_distribution))))
    auroc = sm.roc_auc_score(labels, preds)
    return float(auroc)
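
# A minimal usage sketch (hypothetical rng-generated scores; assumes higher
# values indicate OOD, e.g. epistemic uncertainty):
#
#     >>> import numpy as np
#     >>> rng = np.random.default_rng(0)
#     >>> id_scores = rng.normal(0.2, 0.1, size=500)
#     >>> ood_scores = rng.normal(0.8, 0.1, size=500)
#     >>> out_of_distribution_detection_auroc(id_scores, ood_scores)  # close to 1.0
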
def out_of_distribution_detection_aupr(in_distribution: np.ndarray, out_distribution: np.ndarray) -> float:
    """Perform out-of-distribution detection using AUPR (area under the precision-recall curve).

    This metric evaluates how well the model distinguishes between in- and out-of-distribution
    samples, focusing on precision and recall for the positive (OOD) class.

    Args:
        in_distribution: in-distribution prediction functionals
        out_distribution: out-of-distribution prediction functionals

    Returns:
        aupr: float, area under the precision-recall curve
    """
    preds = np.concatenate((in_distribution, out_distribution))
    labels = np.concatenate((np.zeros(len(in_distribution)), np.ones(len(out_distribution))))
    aupr = sm.average_precision_score(labels, preds)
    return float(aupr)
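
# A minimal sketch on the same hypothetical scores as above. Unlike AUROC,
# AUPR depends on the ID/OOD class balance, since OOD is the positive class:
#
#     >>> out_of_distribution_detection_aupr(id_scores, ood_scores)  # close to 1.0 for well-separated scores
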
def out_of_distribution_detection_fpr_at_x_tpr(
    in_distribution: np.ndarray,
    out_distribution: np.ndarray,
    tpr_target: float = 0.95,
) -> float:
    """Perform out-of-distribution detection using the false positive rate (FPR) at a given true positive rate.

    If no target is specified, the default tpr_target is 0.95. The scores can be
    epistemic uncertainty, as is common, but also e.g. softmax confidence.

    Args:
        in_distribution: numpy.ndarray, scores for in-distribution samples
        out_distribution: numpy.ndarray, scores for out-of-distribution samples
        tpr_target: target TPR value in (0, 1], e.g. 0.95

    Returns:
        fpr_at_target: float, FPR at the first threshold where TPR >= tpr_target

    Notes:
        - Assumes that larger scores correspond to the positive class (out-of-distribution).
        - If tpr_target cannot be reached, a ValueError is raised.
    """
    if not 0.0 < tpr_target <= 1.0:
        msg = f"tpr_target must be in the interval (0, 1], got {tpr_target}."
        raise ValueError(msg)

    preds = np.concatenate((in_distribution, out_distribution))
    labels = np.concatenate(
        (np.zeros(len(in_distribution)), np.ones(len(out_distribution))),
    )

    fpr, tpr, _ = sm.roc_curve(labels, preds)
    idxs = np.where(tpr >= tpr_target)[0]
    if len(idxs) == 0:
        msg = f"Could not achieve TPR >= {tpr_target:.3f} with given scores."
        raise ValueError(msg)

    first_idx = idxs[0]
    fpr_at_target = fpr[first_idx]
    return float(fpr_at_target)
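
# A minimal sketch of FPR@TPR (hypothetical, deliberately overlapping scores
# so the metric is not trivially 0):
#
#     >>> rng = np.random.default_rng(0)
#     >>> id_scores = rng.normal(0.4, 0.2, size=500)
#     >>> ood_scores = rng.normal(0.6, 0.2, size=500)
#     >>> fpr95 = out_of_distribution_detection_fpr_at_x_tpr(id_scores, ood_scores, tpr_target=0.95)
#     >>> 0.0 <= fpr95 <= 1.0
#     True
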
def out_of_distribution_detection_fnr_at_x_tpr(
    in_distribution: np.ndarray,
    out_distribution: np.ndarray,
    tpr_target: float = 0.95,
) -> float:
    """Perform out-of-distribution detection using the false negative rate (FNR) at a given true positive rate.

    If no target is specified, the default tpr_target is 0.95.

    Args:
        in_distribution: in-distribution prediction functionals
        out_distribution: out-of-distribution prediction functionals
        tpr_target: target TPR value in (0, 1], e.g. 0.95

    Returns:
        fnr_at_target: float, FNR at the first threshold where TPR >= tpr_target,
            or 1.0 if the target TPR cannot be reached
    """
    preds = np.concatenate((in_distribution, out_distribution))
    labels = np.concatenate((np.zeros(len(in_distribution)), np.ones(len(out_distribution))))
    _, tpr, _ = sm.roc_curve(labels, preds)
    idx = np.where(tpr >= tpr_target)[0]
    return float(1.0 - tpr[idx[0]]) if len(idx) else 1.0
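
# Because FNR = 1 - TPR, the value returned at the first threshold reaching
# the target satisfies fnr <= 1 - tpr_target. Continuing the hypothetical
# scores from the FPR sketch above:
#
#     >>> fnr95 = out_of_distribution_detection_fnr_at_x_tpr(id_scores, ood_scores, tpr_target=0.95)
#     >>> fnr95 <= 0.05
#     True
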
STATIC_METRICS: dict[str, Callable[[np.ndarray, np.ndarray], float]] = {
    "auroc": out_of_distribution_detection_auroc,
    "aupr": out_of_distribution_detection_aupr,
}

DYNAMIC_METRICS: dict[str, Callable[[np.ndarray, np.ndarray, float], float]] = {
    "fpr": out_of_distribution_detection_fpr_at_x_tpr,
    "fnr": out_of_distribution_detection_fnr_at_x_tpr,
}


def _validate_threshold(threshold: float) -> None:
    """Validate that a threshold lies in (0, 1]."""
    if not 0 < threshold <= 1:
        msg = "threshold must be in (0, 1]"
        raise ValueError(msg)


def _validate_base_metric(base: str) -> None:
    """Validate that a base metric name is known."""
    if base not in DYNAMIC_METRICS:
        msg = f"unknown dynamic metric '{base}'"
        raise ValueError(msg)
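
# The two registries enable name-based dispatch in evaluate_ood below; direct
# lookup is equivalent to calling the functions themselves (hypothetical
# scores as above):
#
#     >>> STATIC_METRICS["auroc"](id_scores, ood_scores)      # same as out_of_distribution_detection_auroc(...)
#     >>> DYNAMIC_METRICS["fpr"](id_scores, ood_scores, 0.9)  # FPR at TPR >= 0.9
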
def parse_dynamic_metric(spec: str) -> tuple[str, float]:
    """Parse a dynamic metric specification.

    Examples:
        fpr@0.8
        fnr@95%
        fpr  -> default threshold is 0.95
        fnr  -> default threshold is 0.95
    """
    try:
        spec = spec.lower().strip()
        if "@" not in spec:
            base = spec
            threshold = 0.95
        else:
            base, t = spec.split("@")
            base = base.strip()
            t = t.strip()
            threshold = float(t[:-1]) / 100 if t.endswith("%") else float(t)
        _validate_threshold(threshold)
        _validate_base_metric(base)
    except Exception as e:
        msg = f"Invalid metric specification '{spec}'. Use e.g. 'fpr', 'fpr@0.8' or 'fnr@90%'."
        raise ValueError(msg) from e
    else:
        return base, threshold
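
# Parsing behavior, following the branches above:
#
#     >>> parse_dynamic_metric("fpr@0.8")
#     ('fpr', 0.8)
#     >>> parse_dynamic_metric("FNR@95%")
#     ('fnr', 0.95)
#     >>> parse_dynamic_metric("fpr")
#     ('fpr', 0.95)
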
def evaluate_ood(
    in_distribution: np.ndarray | list[float],
    out_distribution: np.ndarray | list[float],
    metrics: str | list[str] | None = None,
) -> dict[str, float]:
    """Unified OOD evaluation API.

    Provides backward compatibility while supporting multiple metrics.

    Parameters:
        in_distribution : Scores for in-distribution samples.
        out_distribution : Scores for out-of-distribution samples.
        metrics : str, list of str, or None
            - None or "auroc": returns a single AUROC value (backward compatible)
            - "all": returns a dict with all available metrics
            - list: returns a dict with the specified metrics

    Returns:
        dict[str, float]
            Dictionary mapping metric names to values. If metrics is None or
            "auroc", the dict contains only the "auroc" entry.
    """
    in_s = np.asarray(in_distribution)
    out_s = np.asarray(out_distribution)

    if metrics is None:
        return {"auroc": STATIC_METRICS["auroc"](in_s, out_s)}

    if isinstance(metrics, str):
        if metrics == "auroc":
            return {"auroc": STATIC_METRICS["auroc"](in_s, out_s)}
        metric_list = [*STATIC_METRICS.keys(), "fpr", "fnr"] if metrics == "all" else [metrics]
    else:  # metrics is a list[str]
        metric_list = list(metrics)

    results: dict[str, float] = {}
    for metric_name in metric_list:
        metric_name_lower = metric_name.lower().strip()
        if metric_name_lower in STATIC_METRICS:
            results[metric_name] = STATIC_METRICS[metric_name_lower](in_s, out_s)
        elif metric_name_lower in DYNAMIC_METRICS or "@" in metric_name_lower:
            base, thr = parse_dynamic_metric(metric_name_lower)
            results[metric_name] = DYNAMIC_METRICS[base](in_s, out_s, thr)
        else:
            msg = (
                f"Unknown metric '{metric_name}'. "
                f"Available: {list(STATIC_METRICS.keys())}"
                " + dynamic metrics 'fpr', 'fnr' (optionally with @value)."
            )
            raise ValueError(msg)
    return results
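
# A minimal end-to-end sketch (hypothetical scores as above; higher = more OOD):
#
#     >>> evaluate_ood(id_scores, ood_scores)                  # {'auroc': ...}
#     >>> evaluate_ood(id_scores, ood_scores, metrics="all")   # auroc, aupr, fpr, fnr
#     >>> evaluate_ood(id_scores, ood_scores, metrics=["aupr", "fpr@0.9"])
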
def visualize_ood(
    in_distribution: np.ndarray | list[float],
    out_distribution: np.ndarray | list[float],
    plot_types: list[str] | None = None,
    invert_scores: bool = True,
) -> dict[str, Figure]:
    """Generate visualization plots from OOD scores.

    Parameters:
        in_distribution : Scores for in-distribution samples.
        out_distribution : Scores for out-of-distribution samples.
        plot_types : list[str], optional
            List of specific plots to return (e.g. ['roc', 'hist', 'pr']).
            If None, all plots are generated.
        invert_scores : bool
            If True (default), assumes scores are confidences (high = ID);
            they are inverted (1.0 - score) for metrics where OOD is the
            positive class. If False, assumes scores are anomaly scores
            (high = OOD).

    Returns:
        Dict containing matplotlib Figures for the requested plots.
    """
    id_s = np.asarray(in_distribution)
    ood_s = np.asarray(out_distribution)

    available_plots = {"hist", "roc", "pr"}
    requested_plots = available_plots if plot_types is None else set(plot_types).intersection(available_plots)

    figures = {}
    if "hist" in requested_plots:
        figures["hist"] = plot_histogram(id_scores=id_s, ood_scores=ood_s)

    if "roc" in requested_plots or "pr" in requested_plots:
        if invert_scores:
            id_s_eval = 1.0 - id_s
            ood_s_eval = 1.0 - ood_s
        else:
            id_s_eval = id_s
            ood_s_eval = ood_s

        labels = np.concatenate([np.zeros(len(id_s_eval)), np.ones(len(ood_s_eval))])
        preds = np.concatenate([id_s_eval, ood_s_eval])

        if "roc" in requested_plots:
            fpr, tpr, _ = roc_curve(labels, preds)
            auroc = sm.auc(fpr, tpr)
            idx_95 = np.where(tpr >= 0.95)[0]
            fpr95 = fpr[idx_95[0]] if len(idx_95) > 0 else None
            figures["roc"] = plot_roc_curve(fpr=fpr, tpr=tpr, auroc=auroc, fpr95=fpr95)

        if "pr" in requested_plots:
            precision, recall, _ = precision_recall_curve(labels, preds)
            aupr = sm.auc(recall, precision)
            figures["pr"] = plot_pr_curve(recall=recall, precision=precision, aupr=aupr)

    return figures
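
# A minimal sketch of generating and saving plots. The names id_confidences
# and ood_confidences are hypothetical softmax confidences (high = ID), so the
# default invert_scores=True applies:
#
#     >>> figs = visualize_ood(id_confidences, ood_confidences, plot_types=["roc", "hist"])
#     >>> figs["roc"].savefig("ood_roc.png")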