# Source code for probly.transformation.ensemble.common
"""Shared ensemble implementation."""
from __future__ import annotations
from typing import TYPE_CHECKING
from lazy_dispatch import lazydispatch
if TYPE_CHECKING:
from collections.abc import Callable
from lazy_dispatch.isinstance import LazyType
from probly.predictor import Predictor
@lazydispatch
def ensemble_generator[In, KwIn, Out](base: Predictor[In, KwIn, Out]) -> Predictor[In, KwIn, Out]:
"""Generate an ensemble from a base model."""
msg = f"No ensemble generator is registered for type {type(base)}"
raise NotImplementedError(msg)
def register(cls: LazyType, generator: Callable) -> None:
    """Make ``cls`` usable as the base of an ensemble.

    Args:
        cls: The type to associate with ``generator``.
        generator: The callable registered on ``ensemble_generator`` for ``cls``.
    """
    add_implementation = ensemble_generator.register
    add_implementation(cls=cls, func=generator)
# [docs]
def ensemble[T: Predictor](base: T, num_members: int, reset_params: bool = True) -> T:
"""Create an ensemble predictor from a base predictor based on :cite:`lakshminarayananSimpleScalable2017`.
Args:
base: Predictor, The base model to be used for the ensemble.
num_members: The number of members in the ensemble.
reset_params: Whether to reset the parameters of each member.
Returns:
Predictor, The ensemble predictor.
"""
return ensemble_generator(base, num_members=num_members, reset_params=reset_params)