__sklearn_is_fitted__ 作為開發者 API#

__sklearn_is_fitted__ 方法是 scikit-learn 中用於檢查估計器物件是否已擬合的慣例。此方法通常在建立於 scikit-learn 的基本類別(例如 BaseEstimator 或其子類別)之上的自訂估計器類別中實作。

開發人員應在所有方法(除了 fit)的開頭使用 check_is_fitted。如果他們需要自訂或加速檢查,他們可以如下所示實作 __sklearn_is_fitted__ 方法。

在此範例中,自訂估計器展示了 __sklearn_is_fitted__ 方法和 check_is_fitted 實用程式函式作為開發人員 API 的用法。__sklearn_is_fitted__ 方法通過驗證 _is_fitted 屬性的存在來檢查擬合狀態。

實作簡單分類器的範例自訂估計器#

此程式碼片段定義了一個名為 CustomEstimator 的自訂估計器類別,它擴展了 scikit-learn 中的 BaseEstimatorClassifierMixin 類別,並展示了 __sklearn_is_fitted__ 方法和 check_is_fitted 實用程式函式的用法。

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.validation import check_is_fitted


class CustomEstimator(BaseEstimator, ClassifierMixin):
    def __init__(self, parameter=1):
        self.parameter = parameter

    def fit(self, X, y):
        """
        Fit the estimator to the training data.
        """
        self.classes_ = sorted(set(y))
        # Custom attribute to track if the estimator is fitted
        self._is_fitted = True
        return self

    def predict(self, X):
        """
        Perform Predictions

        If the estimator is not fitted, then raise NotFittedError
        """
        check_is_fitted(self)
        # Perform prediction logic
        predictions = [self.classes_[0]] * len(X)
        return predictions

    def score(self, X, y):
        """
        Calculate Score

        If the estimator is not fitted, then raise NotFittedError
        """
        check_is_fitted(self)
        # Perform scoring logic
        return 0.5

    def __sklearn_is_fitted__(self):
        """
        Check fitted status and return a Boolean value.
        """
        return hasattr(self, "_is_fitted") and self._is_fitted

相關範例

歸納式聚類

歸納式聚類

具有自訂核的 SVM

具有自訂核的 SVM

scikit-learn 1.6 的發行重點

scikit-learn 1.6 的發行重點

中繼資料路由

中繼資料路由

由 Sphinx-Gallery 產生之圖庫