probly.train.evidential.torch.unified_evidential_train
- probly.train.evidential.torch.unified_evidential_train(mode: Literal['PostNet', 'NatPostNet', 'EDL', 'PrNet', 'IRD', 'DER', 'RPN'], model: nn.Module, dataloader: DataLoader, loss_fn: Callable[..., torch.Tensor] | None = None, oodloader: DataLoader | None = None, class_count: torch.Tensor | None = None, epochs: int = 5, lr: float = 0.001, device: str = 'cpu') → None
Trains the given neural network with the training procedure of the selected paper-based evidential approach (chosen via mode).
- Parameters:
mode – Identifier of the paper-based training approach to be used. Must be one of: “PostNet”, “NatPostNet”, “EDL”, “PrNet”, “IRD”, “DER” or “RPN”.
model – The neural network to be trained.
dataloader – PyTorch DataLoader providing the in-distribution (ID) training samples and their labels.
loss_fn – Loss function used for training. The inputs it receives depend on the selected mode.
oodloader – PyTorch DataLoader providing the out-of-distribution (OOD) training samples and their labels. Only required by certain modes, such as “PrNet”.
class_count – Tensor containing the number of samples per class.
epochs – Number of training epochs.
lr – Learning rate used by the optimizer.
device – Device on which the model is trained (e.g. “cpu” or “cuda”).
- Returns:
None. The function trains the provided model in place and does not return a value; it prints the total loss of each epoch.
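The documented behaviour (train in place for epochs epochs at learning rate lr on device, print the total loss per epoch, return None) can be illustrated with a minimal, generic PyTorch loop. This is a sketch under stated assumptions, not probly's implementation: the name train_sketch, the Adam optimizer, the loss_fn(outputs, targets) call convention, and the omission of oodloader and class_count are all assumptions. The mode-specific losses and data handling inside unified_evidential_train may differ.

```python
from collections.abc import Callable

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_sketch(
    model: nn.Module,
    dataloader: DataLoader,
    loss_fn: Callable[..., torch.Tensor],
    epochs: int = 5,
    lr: float = 0.001,
    device: str = "cpu",
) -> None:
    """Hypothetical, simplified training loop (not probly's implementation)."""
    model.to(device)
    model.train()
    # Assumption: Adam is used here; the optimizer probly uses is not documented above.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for epoch in range(epochs):
        total_loss = 0.0
        # Only in-distribution batches are used; OOD handling (e.g. for "PrNet") is omitted.
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            # Assumption: the loss takes (model outputs, targets).
            loss = loss_fn(model(inputs), targets)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        # Report the summed batch losses for this epoch, as the Returns section describes.
        print(f"Epoch {epoch + 1}/{epochs}: total loss = {total_loss:.4f}")


# Usage on random toy data: 64 samples, 4 features, 3 classes.
torch.manual_seed(0)
x = torch.randn(64, 4)
y = torch.randint(0, 3, (64,))
loader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)
model = nn.Linear(4, 3)
train_sketch(model, loader, nn.CrossEntropyLoss(), epochs=2)
```

Summing the per-batch losses matches the “total loss per epoch” wording; dividing by len(dataloader) would report a mean batch loss instead.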