# Source code for probly.representation.sampling.sklearn_sampler

"""Sampling preparation for sklearn."""

from __future__ import annotations

from typing import TYPE_CHECKING

from sklearn.base import BaseEstimator

from . import sampler

if TYPE_CHECKING:
    from lazy_dispatch.isinstance import LazyType
    from pytraverse import State


def _enforce_fitted_already(obj: BaseEstimator, state: State) -> tuple[BaseEstimator, State]:
    """Ensure that the sklearn estimator has been fitted before sampling.

    The presence of ``n_features_in_`` is accepted as a cheap fast path. That
    attribute is not defined by every fitted estimator, though (for example
    estimators fitted on non-tabular input such as raw text), so its absence
    alone is not treated as "unfitted". In that case the decision is delegated
    to ``sklearn.utils.validation.check_is_fitted``, scikit-learn's own
    fitted-state check.

    See: https://scikit-learn.org/stable/glossary.html#term-fitted

    Args:
        obj: The estimator to validate.
        state: The traversal state; it is passed through untouched.

    Returns:
        The unchanged ``(obj, state)`` pair.

    Raises:
        ValueError: If the estimator is not recognised as fitted.
    """
    # Fast path: this attribute only appears once ``fit`` has run.
    if hasattr(obj, "n_features_in_"):
        return obj, state

    # Imported here because these helpers are only needed on the fallback path.
    from sklearn.exceptions import NotFittedError
    from sklearn.utils.validation import check_is_fitted

    try:
        check_is_fitted(obj)
    except (NotFittedError, TypeError) as err:
        # ``TypeError`` is raised for objects that are not estimator instances;
        # both cases keep surfacing as the ``ValueError`` callers already expect.
        msg = "The sklearn estimator must be fitted already before sampling."
        raise ValueError(msg) from err
    return obj, state


def register_forced_fitted_already_mode(cls: LazyType) -> None:
    """Register a class to be forced into fitted already mode during sampling.

    The class is registered on ``sampler.sampling_preparation_traverser`` with
    ``_enforce_fitted_already`` as its handler.

    Args:
        cls: The class (or lazy type reference) to register.
    """
    sampler.sampling_preparation_traverser.register(
        cls,
        _enforce_fitted_already,
    )


# Registered at import time: sklearn's ``BaseEstimator`` gets the fitted check.
register_forced_fitted_already_mode(BaseEstimator)