probly.calibration.isotonic_regression.torch

Torch implementation of isotonic regression calibration.

Classes

IsotonicRegressionCalibrator(base_model, ...)

Post-hoc isotonic regression calibration for a PyTorch model.

class probly.calibration.isotonic_regression.torch.IsotonicRegressionCalibrator(base_model, use_logits)

Bases: object

Post-hoc isotonic regression calibration for a PyTorch model.

Parameters:
  • base_model (nn.Module) – The model whose outputs are calibrated

  • use_logits (bool)
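For reference, isotonic regression calibration is usually stated, for a single score per sample, as the least-squares problem below. Here s_1, ..., s_n are scores produced by base_model on the calibration data, y_i ∈ {0, 1} are the matching labels, and s(x) is the score for a new input x. How this class applies the idea to multiclass outputs is not documented on this page, so the formula is only the textbook binary form.

```latex
\hat{f} = \operatorname*{arg\,min}_{f \,\text{non-decreasing}} \; \sum_{i=1}^{n} \bigl(y_i - f(s_i)\bigr)^2,
\qquad
\hat{p}(x) = \hat{f}\bigl(s(x)\bigr)
```

Because \hat{f} is only constrained to be monotone, the fitted map is a step function whose values are averages of the labels over contiguous runs of sorted scores.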

fit(calibration_set)

Fit the isotonic regression function to the outputs of base_model on the calibration set.

Parameters:

calibration_set (DataLoader) – The data used to fit the calibration mapping

Return type:

None
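The fitting step can be illustrated with the pool-adjacent-violators (PAV) algorithm, the standard exact solver for isotonic regression. This is a minimal, dependency-free sketch, not probly's implementation: it assumes one score and one binary label per sample (for example, collected by running base_model over the DataLoader), and the function name fit_isotonic and its (thresholds, values) return format are hypothetical.

```python
from itertools import groupby
from typing import List, Tuple


def fit_isotonic(scores: List[float], labels: List[float]) -> Tuple[List[float], List[float]]:
    """Fit a non-decreasing step function to (score, label) pairs with PAV.

    Hypothetical helper: values[i] is the calibrated probability for scores
    from thresholds[i] up to (but excluding) thresholds[i + 1].
    """
    if not scores or len(scores) != len(labels):
        raise ValueError("need the same, non-zero number of scores and labels")

    # Merge tied scores first so each distinct score gets one fitted value.
    # Each block is [sum of labels, number of samples, lowest score in block].
    grouped = []
    for score, group in groupby(sorted(zip(scores, labels)), key=lambda pair: pair[0]):
        ys = [float(y) for _, y in group]
        grouped.append([sum(ys), float(len(ys)), score])

    blocks = []
    for block in grouped:
        blocks.append(block)
        # Pool adjacent violators: while the previous block's mean exceeds the
        # last block's mean, merge the last block into the previous one.
        while len(blocks) > 1 and blocks[-2][0] / blocks[-2][1] > blocks[-1][0] / blocks[-1][1]:
            total, count, _ = blocks.pop()
            blocks[-1][0] += total
            blocks[-1][1] += count

    thresholds = [block[2] for block in blocks]
    values = [block[0] / block[1] for block in blocks]
    return thresholds, values
```

For example, fit_isotonic([0.1, 0.2, 0.3, 0.4], [0, 1, 0, 1]) pools the violating pair at scores 0.2 and 0.3 into one block with mean 0.5, giving thresholds [0.1, 0.2, 0.4] and values [0.0, 0.5, 1.0].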

predict(x)

Make calibrated predictions on the input x.

Parameters:

x (Tensor) – The input for the model to make predictions on

Returns:

calibrated_probs – The calibrated probabilities for the input x

Return type:

Tensor
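Prediction then amounts to evaluating the fitted monotone map at each new score. The sketch below assumes the map is stored as sorted thresholds with one value per step (a hypothetical format, not necessarily how probly stores it) and uses plain Python lists instead of tensors to stay dependency-free; predict_isotonic is a hypothetical name.

```python
from bisect import bisect_right
from typing import List


def predict_isotonic(scores: List[float], thresholds: List[float], values: List[float]) -> List[float]:
    """Map raw scores to calibrated probabilities with a fitted step function.

    Hypothetical helper: thresholds must be non-empty and sorted ascending;
    values[i] applies from thresholds[i] up to (but excluding) thresholds[i + 1].
    """
    calibrated = []
    for s in scores:
        # Index of the last threshold <= s; scores below the first threshold
        # are clamped to the first step.
        i = max(bisect_right(thresholds, s) - 1, 0)
        calibrated.append(values[i])
    return calibrated
```

With thresholds [0.1, 0.2, 0.4] and values [0.0, 0.5, 1.0], the scores [0.05, 0.1, 0.3, 0.4, 0.9] map to [0.0, 0.0, 0.5, 1.0, 1.0]. A step lookup is the simplest choice; some implementations, such as scikit-learn's IsotonicRegression, instead interpolate linearly between the fitted points.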