probly.train.dare.torch.dare_regularizer

probly.train.dare.torch.dare_regularizer(model: Module, device: device | str, loss: Tensor, threshold: Tensor | float) → Tensor

Compute the DARE (Deep Anti-Regularized Ensembles) anti-regularization term following Algorithm 1.

Parameters:
  • model – The DARE model.

  • device – The device of the model.

  • loss – The current loss value, used for the switching condition.

  • threshold – The threshold at or below which anti-regularization activates.

Returns:
  The anti-regularization term when loss <= threshold, otherwise 0.0.
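The switching rule described above can be sketched as follows. This is a minimal, hypothetical stand-in (`anti_regularizer_sketch`), not probly's implementation. The penalty it computes is an assumption: the negative mean of log(w²) over all trainable parameters, chosen only because minimizing it pushes weight magnitudes up. The sketch also assumes the returned term is added to the task loss.

```python
from __future__ import annotations

import torch
from torch import nn


def anti_regularizer_sketch(
    model: nn.Module,
    device: torch.device | str,
    loss: torch.Tensor,
    threshold: torch.Tensor | float,
    eps: float = 1e-8,
) -> torch.Tensor:
    """Hypothetical stand-in for dare_regularizer (the penalty form is assumed)."""
    # Documented switching condition: the term is inactive while loss > threshold.
    if loss.item() > float(threshold):
        return torch.zeros((), device=device)
    # ASSUMED penalty: negative mean of log(w^2) over all trainable parameters.
    # Minimizing it increases weight magnitudes (anti-regularization).
    weights = torch.cat([p.flatten() for p in model.parameters() if p.requires_grad])
    return -torch.log(weights.pow(2) + eps).mean().to(device)


# Example: one training step with the term added to the task loss (assumed usage).
torch.manual_seed(0)
model = nn.Linear(4, 2)
x, y = torch.randn(8, 4), torch.randn(8, 2)
task_loss = nn.functional.mse_loss(model(x), y)
reg = anti_regularizer_sketch(model, "cpu", task_loss, threshold=10.0)
(task_loss + reg).backward()
```

The comparison uses `loss.item()`, so the switching decision itself carries no gradient; only the penalty contributes to the backward pass.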