中繼資料路由#

此文件說明如何使用 scikit-learn 中的 中繼資料路由機制,將中繼資料路由到使用它們的估計器、評分器和 CV 分割器。

為了更好地理解以下文件,我們需要介紹兩個概念:路由器和消費者。路由器是一個將一些給定的資料和中繼資料轉發到其他物件的物件。在大多數情況下,路由器是一個 元估計器,即將另一個估計器作為參數的估計器。諸如 sklearn.model_selection.cross_validate 之類的函數,它將一個估計器作為參數並轉發資料和中繼資料,也是一個路由器。

另一方面,消費者是一個接受和使用一些給定中繼資料的物件。例如,一個在其 fit 方法中考慮 sample_weight 的估計器是 sample_weight 的消費者。

一個物件有可能同時是路由器和消費者。例如,元估計器可能會在某些計算中考慮 sample_weight,但它也可能會將其路由到基礎估計器。

首先是一些匯入和一些用於腳本剩餘部分的隨機資料。

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
import warnings
from pprint import pprint

import numpy as np

from sklearn import set_config
from sklearn.base import (
    BaseEstimator,
    ClassifierMixin,
    MetaEstimatorMixin,
    RegressorMixin,
    TransformerMixin,
    clone,
)
from sklearn.linear_model import LinearRegression
from sklearn.utils import metadata_routing
from sklearn.utils.metadata_routing import (
    ,
    ,
    ,
    ,
)
from sklearn.utils.validation import check_is_fitted

n_samples, n_features = 100, 4
rng = np.random.RandomState(42)
X = rng.rand(n_samples, n_features)
y = rng.randint(0, 2, size=n_samples)
my_groups = rng.randint(0, 10, size=n_samples)
my_weights = rng.rand(n_samples)
my_other_weights = rng.rand(n_samples)

只有明確啟用中繼資料路由時,它才可用

set_config(enable_metadata_routing=True)

此實用程式函數是一個虛擬函數,用於檢查是否傳遞了中繼資料

def check_metadata(obj, **kwargs):
    for key, value in kwargs.items():
        if value is not None:
            print(
                f"Received {key} of length = {len(value)} in {obj.__class__.__name__}."
            )
        else:
            print(f"{key} is None in {obj.__class__.__name__}.")

一個實用程式函數,用於漂亮地列印物件的路由資訊

def print_routing(obj):
    pprint(obj.get_metadata_routing()._serialize())

使用中繼資料的估計器#

在這裡,我們示範估計器如何公開所需的 API,以支援中繼資料路由作為消費者。想像一個簡單的分類器,它在其 fit 方法中接受 sample_weight 作為中繼資料,並在其 predict 方法中接受 groups

class ExampleClassifier(ClassifierMixin, BaseEstimator):
    def fit(self, X, y, sample_weight=None):
        check_metadata(self, sample_weight=sample_weight)
        # all classifiers need to expose a classes_ attribute once they're fit.
        self.classes_ = np.array([0, 1])
        return self

    def predict(self, X, groups=None):
        check_metadata(self, groups=groups)
        # return a constant value of 1, not a very smart classifier!
        return np.ones(len(X))

上述估計器現在擁有使用中繼資料所需的一切。這是透過在 BaseEstimator 中完成的一些魔術來完成的。現在,上述類別公開了三個方法:set_fit_requestset_predict_requestget_metadata_routing。還有一個 set_score_request 用於 sample_weight,它存在是因為 ClassifierMixin 實作了一個接受 sample_weightscore 方法。這同樣適用於繼承自 RegressorMixin 的回歸器。

預設情況下,不要求任何中繼資料,我們可以將其視為

print_routing(ExampleClassifier())
{'fit': {'sample_weight': None},
 'predict': {'groups': None},
 'score': {'sample_weight': None}}

上述輸出表示 ExampleClassifier 沒有要求 sample_weightgroups,如果路由器收到這些中繼資料,則會引發錯誤,因為使用者尚未明確設定是否需要它們。對於 score 方法中的 sample_weight 也是如此,該方法繼承自 ClassifierMixin。為了明確設定這些中繼資料的請求值,我們可以使用這些方法

est = (
    ExampleClassifier()
    .set_fit_request(sample_weight=False)
    .set_predict_request(groups=True)
    .set_score_request(sample_weight=False)
)
print_routing(est)
{'fit': {'sample_weight': False},
 'predict': {'groups': True},
 'score': {'sample_weight': False}}

注意

請注意,只要上述估計器未在元估計器中使用,使用者就不需要為中繼資料設定任何請求,並且設定的值會被忽略,因為消費者不會驗證或路由給定的中繼資料。上述估計器的簡單使用會按預期工作。

est = ExampleClassifier()
est.fit(X, y, sample_weight=my_weights)
est.predict(X[:3, :], groups=my_groups)
Received sample_weight of length = 100 in ExampleClassifier.
Received groups of length = 100 in ExampleClassifier.

array([1., 1., 1.])

路由元估計器#

現在,我們展示如何設計一個元估計器作為路由器。作為一個簡化的範例,這是一個元估計器,除了路由中繼資料之外,它沒有做太多其他事情。

class MetaClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator):
    def __init__(self, estimator):
        self.estimator = estimator

    def get_metadata_routing(self):
        # This method defines the routing for this meta-estimator.
        # In order to do so, a `MetadataRouter` instance is created, and the
        # routing is added to it. More explanations follow below.
        router = (owner=self.__class__.__name__).add(
            estimator=self.estimator,
            method_mapping=()
            .add(caller="fit", callee="fit")
            .add(caller="predict", callee="predict")
            .add(caller="score", callee="score"),
        )
        return router

    def fit(self, X, y, **fit_params):
        # `get_routing_for_object` returns a copy of the `MetadataRouter`
        # constructed by the above `get_metadata_routing` method, that is
        # internally called.
        request_router = (self)
        # Meta-estimators are responsible for validating the given metadata.
        # `method` refers to the parent's method, i.e. `fit` in this example.
        request_router.validate_metadata(params=fit_params, method="fit")
        # `MetadataRouter.route_params` maps the given metadata to the metadata
        # required by the underlying estimator based on the routing information
        # defined by the MetadataRouter. The output of type `Bunch` has a key
        # for each consuming object and those hold keys for their consuming
        # methods, which then contain key for the metadata which should be
        # routed to them.
        routed_params = request_router.route_params(params=fit_params, caller="fit")

        # A sub-estimator is fitted and its classes are attributed to the
        # meta-estimator.
        self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit)
        self.classes_ = self.estimator_.classes_
        return self

    def predict(self, X, **predict_params):
        check_is_fitted(self)
        # As in `fit`, we get a copy of the object's MetadataRouter,
        request_router = (self)
        # then we validate the given metadata,
        request_router.validate_metadata(params=predict_params, method="predict")
        # and then prepare the input to the underlying `predict` method.
        routed_params = request_router.route_params(
            params=predict_params, caller="predict"
        )
        return self.estimator_.predict(X, **routed_params.estimator.predict)

讓我們分解上述程式碼的不同部分。

首先,get_routing_for_object 採用我們的元估計器 (self) 並傳回一個 MetadataRouter 或,如果物件是消費者,則傳回 MetadataRequest,基於估計器的 get_metadata_routing 方法的輸出。

然後在每個方法中,我們使用 route_params 方法來建構一個形式為 {"object_name": {"method_name": {"metadata": value}}} 的字典,以傳遞給基礎估計器的方法。 object_name (上述 routed_params.estimator.fit 範例中的 estimator) 與在 get_metadata_routing 中新增的名稱相同。 validate_metadata 確保請求所有給定的中繼資料,以避免靜默錯誤。

接下來,我們說明不同的行為,特別是引發的錯誤類型。

meta_est = MetaClassifier(
    estimator=ExampleClassifier().set_fit_request(sample_weight=True)
)
meta_est.fit(X, y, sample_weight=my_weights)
Received sample_weight of length = 100 in ExampleClassifier.
MetaClassifier(estimator=ExampleClassifier())
在 Jupyter 環境中,請重新執行此儲存格以顯示 HTML 表示法或信任筆記本。
在 GitHub 上,HTML 表示法無法呈現,請嘗試使用 nbviewer.org 加載此頁面。


請注意,上述範例是透過 ExampleClassifier 呼叫我們的工具函式 check_metadata()。它會檢查 sample_weight 是否正確傳遞給它。如果沒有,如下列範例所示,它會印出 sample_weightNone

meta_est.fit(X, y)
sample_weight is None in ExampleClassifier.
MetaClassifier(estimator=ExampleClassifier())
在 Jupyter 環境中,請重新執行此儲存格以顯示 HTML 表示法或信任筆記本。
在 GitHub 上,HTML 表示法無法呈現,請嘗試使用 nbviewer.org 加載此頁面。


如果我們傳遞未知的元數據,則會引發錯誤

try:
    meta_est.fit(X, y, test=my_weights)
except TypeError as e:
    print(e)
MetaClassifier.fit got unexpected argument(s) {'test'}, which are not routed to any object.

如果我們傳遞的元數據未明確要求

try:
    meta_est.fit(X, y, sample_weight=my_weights).predict(X, groups=my_groups)
except ValueError as e:
    print(e)
Received sample_weight of length = 100 in ExampleClassifier.
[groups] are passed but are not explicitly set as requested or not requested for ExampleClassifier.predict, which is used within MetaClassifier.predict. Call `ExampleClassifier.set_predict_request({metadata}=True/False)` for each metadata you want to request/ignore.

此外,如果我們明確設定為不要求,但它卻被提供

meta_est = MetaClassifier(
    estimator=ExampleClassifier()
    .set_fit_request(sample_weight=True)
    .set_predict_request(groups=False)
)
try:
    meta_est.fit(X, y, sample_weight=my_weights).predict(X[:3, :], groups=my_groups)
except TypeError as e:
    print(e)
Received sample_weight of length = 100 in ExampleClassifier.
MetaClassifier.predict got unexpected argument(s) {'groups'}, which are not routed to any object.

另一個要介紹的概念是別名元數據。當估算器要求的元數據的變數名稱與預設變數名稱不同時,就會發生這種情況。例如,在管道中有兩個估算器的情況下,一個估算器可能會要求 sample_weight1,另一個則要求 sample_weight2。請注意,這不會改變估算器的期望,它只會告訴元估算器如何將提供的元數據對應到所需的内容。這裡有一個範例,我們將 aliased_sample_weight 傳遞給元估算器,但元估算器理解 aliased_sample_weightsample_weight 的別名,並將其作為 sample_weight 傳遞給底層的估算器

meta_est = MetaClassifier(
    estimator=ExampleClassifier().set_fit_request(sample_weight="aliased_sample_weight")
)
meta_est.fit(X, y, aliased_sample_weight=my_weights)
Received sample_weight of length = 100 in ExampleClassifier.
MetaClassifier(estimator=ExampleClassifier())
在 Jupyter 環境中,請重新執行此儲存格以顯示 HTML 表示法或信任筆記本。
在 GitHub 上,HTML 表示法無法呈現,請嘗試使用 nbviewer.org 加載此頁面。


在此處傳遞 sample_weight 將會失敗,因為它是以別名請求的,而名稱為 sample_weight 的並未被請求。

try:
    meta_est.fit(X, y, sample_weight=my_weights)
except TypeError as e:
    print(e)
MetaClassifier.fit got unexpected argument(s) {'sample_weight'}, which are not routed to any object.

這將引導我們了解 get_metadata_routing。scikit-learn 中路由的工作方式是,消費者請求他們需要的內容,而路由器則將其傳遞出去。此外,路由器會公開它本身的需求,以便可以在另一個路由器內部使用,例如,在網格搜尋物件內的管道。 get_metadata_routing 的輸出是 MetadataRouter 的字典表示形式,其中包含所有巢狀物件所請求的完整元數據樹狀結構,以及它們對應的方法路由,也就是說,子估算器的哪個方法會在元估算器的哪個方法中使用

print_routing(meta_est)
{'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'},
                           {'callee': 'predict', 'caller': 'predict'},
                           {'callee': 'score', 'caller': 'score'}],
               'router': {'fit': {'sample_weight': 'aliased_sample_weight'},
                          'predict': {'groups': None},
                          'score': {'sample_weight': None}}}}

如您所見,方法 fit 唯一請求的元數據是 "sample_weight",其別名為 "aliased_sample_weight"~utils.metadata_routing.MetadataRouter 類別讓我們能夠輕鬆建立路由物件,從而為我們的 get_metadata_routing 建立所需的輸出。

為了理解別名在元估算器中的工作方式,想像一下我們的元估算器在另一個元估算器中

meta_meta_est = MetaClassifier(estimator=meta_est).fit(
    X, y, aliased_sample_weight=my_weights
)
Received sample_weight of length = 100 in ExampleClassifier.

在上述範例中,這就是 meta_meta_estfit 方法如何呼叫其子估算器的 fit 方法

# user feeds `my_weights` as `aliased_sample_weight` into `meta_meta_est`:
meta_meta_est.fit(X, y, aliased_sample_weight=my_weights):
    ...

    # the first sub-estimator (`meta_est`) expects `aliased_sample_weight`
    self.estimator_.fit(X, y, aliased_sample_weight=aliased_sample_weight):
        ...

        # the second sub-estimator (`est`) expects `sample_weight`
        self.estimator_.fit(X, y, sample_weight=aliased_sample_weight):
            ...

使用和路由元估算器#

對於稍微複雜一點的範例,請考慮一個元估算器,它像之前一樣將元數據路由到基礎估算器,但它也在自己的方法中使用一些元數據。這個元估算器同時是消費者和路由器。實作一個與我們之前的實作非常相似,但有一些調整。

class RouterConsumerClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator):
    def __init__(self, estimator):
        self.estimator = estimator

    def get_metadata_routing(self):
        router = (
            (owner=self.__class__.__name__)
            # defining metadata routing request values for usage in the meta-estimator
            .add_self_request(self)
            # defining metadata routing request values for usage in the sub-estimator
            .add(
                estimator=self.estimator,
                method_mapping=()
                .add(caller="fit", callee="fit")
                .add(caller="predict", callee="predict")
                .add(caller="score", callee="score"),
            )
        )
        return router

    # Since `sample_weight` is used and consumed here, it should be defined as
    # an explicit argument in the method's signature. All other metadata which
    # are only routed, will be passed as `**fit_params`:
    def fit(self, X, y, sample_weight, **fit_params):
        if self.estimator is None:
            raise ValueError("estimator cannot be None!")

        check_metadata(self, sample_weight=sample_weight)

        # We add `sample_weight` to the `fit_params` dictionary.
        if sample_weight is not None:
            fit_params["sample_weight"] = sample_weight

        request_router = (self)
        request_router.validate_metadata(params=fit_params, method="fit")
        routed_params = request_router.route_params(params=fit_params, caller="fit")
        self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit)
        self.classes_ = self.estimator_.classes_
        return self

    def predict(self, X, **predict_params):
        check_is_fitted(self)
        # As in `fit`, we get a copy of the object's MetadataRouter,
        request_router = (self)
        # we validate the given metadata,
        request_router.validate_metadata(params=predict_params, method="predict")
        # and then prepare the input to the underlying ``predict`` method.
        routed_params = request_router.route_params(
            params=predict_params, caller="predict"
        )
        return self.estimator_.predict(X, **routed_params.estimator.predict)

上述元估算器與我們之前的元估算器不同的關鍵部分是在 fit 中明確接受 sample_weight,並將其包含在 fit_params 中。由於 sample_weight 是一個明確的參數,我們可以確定此方法存在 set_fit_request(sample_weight=...)。這個元估算器同時是 sample_weight 的消費者和路由器。

get_metadata_routing 中,我們使用 add_self_requestself 新增到路由中,以指示這個估算器正在使用 sample_weight,並且也是一個路由器;這也會將 $self_request 鍵新增到路由資訊中,如下所示。現在讓我們來看一些範例

  • 未請求任何元數據

meta_est = RouterConsumerClassifier(estimator=ExampleClassifier())
print_routing(meta_est)
{'$self_request': {'fit': {'sample_weight': None},
                   'score': {'sample_weight': None}},
 'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'},
                           {'callee': 'predict', 'caller': 'predict'},
                           {'callee': 'score', 'caller': 'score'}],
               'router': {'fit': {'sample_weight': None},
                          'predict': {'groups': None},
                          'score': {'sample_weight': None}}}}
  • 子估算器請求 sample_weight

meta_est = RouterConsumerClassifier(
    estimator=ExampleClassifier().set_fit_request(sample_weight=True)
)
print_routing(meta_est)
{'$self_request': {'fit': {'sample_weight': None},
                   'score': {'sample_weight': None}},
 'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'},
                           {'callee': 'predict', 'caller': 'predict'},
                           {'callee': 'score', 'caller': 'score'}],
               'router': {'fit': {'sample_weight': True},
                          'predict': {'groups': None},
                          'score': {'sample_weight': None}}}}
  • 元估算器請求 sample_weight

meta_est = RouterConsumerClassifier(estimator=ExampleClassifier()).set_fit_request(
    sample_weight=True
)
print_routing(meta_est)
{'$self_request': {'fit': {'sample_weight': True},
                   'score': {'sample_weight': None}},
 'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'},
                           {'callee': 'predict', 'caller': 'predict'},
                           {'callee': 'score', 'caller': 'score'}],
               'router': {'fit': {'sample_weight': None},
                          'predict': {'groups': None},
                          'score': {'sample_weight': None}}}}

請注意上述請求的元數據表示方式的差異。

  • 我們也可以為元數據建立別名,以便將不同的值傳遞給元估算器和子估算器的 fit 方法

meta_est = RouterConsumerClassifier(
    estimator=ExampleClassifier().set_fit_request(sample_weight="clf_sample_weight"),
).set_fit_request(sample_weight="meta_clf_sample_weight")
print_routing(meta_est)
{'$self_request': {'fit': {'sample_weight': 'meta_clf_sample_weight'},
                   'score': {'sample_weight': None}},
 'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'},
                           {'callee': 'predict', 'caller': 'predict'},
                           {'callee': 'score', 'caller': 'score'}],
               'router': {'fit': {'sample_weight': 'clf_sample_weight'},
                          'predict': {'groups': None},
                          'score': {'sample_weight': None}}}}

但是,元估算器的 fit 僅需要子估算器的別名,並將自己的樣本權重視為 sample_weight,因為它不會驗證和路由自己需要的元數據

meta_est.fit(X, y, sample_weight=my_weights, clf_sample_weight=my_other_weights)
Received sample_weight of length = 100 in RouterConsumerClassifier.
Received sample_weight of length = 100 in ExampleClassifier.
RouterConsumerClassifier(estimator=ExampleClassifier())
在 Jupyter 環境中,請重新執行此儲存格以顯示 HTML 表示法或信任筆記本。
在 GitHub 上,HTML 表示法無法呈現,請嘗試使用 nbviewer.org 加載此頁面。


  • 僅在子估算器上建立別名

當我們不希望元估算器使用元數據,但子估算器應該使用時,這會很有用。

meta_est = RouterConsumerClassifier(
    estimator=ExampleClassifier().set_fit_request(sample_weight="aliased_sample_weight")
)
print_routing(meta_est)
{'$self_request': {'fit': {'sample_weight': None},
                   'score': {'sample_weight': None}},
 'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'},
                           {'callee': 'predict', 'caller': 'predict'},
                           {'callee': 'score', 'caller': 'score'}],
               'router': {'fit': {'sample_weight': 'aliased_sample_weight'},
                          'predict': {'groups': None},
                          'score': {'sample_weight': None}}}}

元估算器無法使用 aliased_sample_weight,因為它期望將其作為 sample_weight 傳遞。即使在其上設定了 set_fit_request(sample_weight=True),也會適用此情況。

簡單管道#

稍微複雜一點的使用案例是類似於 Pipeline 的元估算器。以下是一個元估算器,它接受轉換器和分類器。在呼叫其 fit 方法時,它會先套用轉換器的 fittransform,然後在轉換後的資料上執行分類器。在 predict 時,它會先套用轉換器的 transform,然後對轉換後的新資料使用分類器的 predict 方法進行預測。

class SimplePipeline(ClassifierMixin, BaseEstimator):
    def __init__(self, transformer, classifier):
        self.transformer = transformer
        self.classifier = classifier

    def get_metadata_routing(self):
        router = (
            (owner=self.__class__.__name__)
            # We add the routing for the transformer.
            .add(
                transformer=self.transformer,
                method_mapping=()
                # The metadata is routed such that it retraces how
                # `SimplePipeline` internally calls the transformer's `fit` and
                # `transform` methods in its own methods (`fit` and `predict`).
                .add(caller="fit", callee="fit")
                .add(caller="fit", callee="transform")
                .add(caller="predict", callee="transform"),
            )
            # We add the routing for the classifier.
            .add(
                classifier=self.classifier,
                method_mapping=()
                .add(caller="fit", callee="fit")
                .add(caller="predict", callee="predict"),
            )
        )
        return router

    def fit(self, X, y, **fit_params):
        routed_params = (self, "fit", **fit_params)

        self.transformer_ = clone(self.transformer).fit(
            X, y, **routed_params.transformer.fit
        )
        X_transformed = self.transformer_.transform(
            X, **routed_params.transformer.transform
        )

        self.classifier_ = clone(self.classifier).fit(
            X_transformed, y, **routed_params.classifier.fit
        )
        return self

    def predict(self, X, **predict_params):
        routed_params = (self, "predict", **predict_params)

        X_transformed = self.transformer_.transform(
            X, **routed_params.transformer.transform
        )
        return self.classifier_.predict(
            X_transformed, **routed_params.classifier.predict
        )

請注意使用 MethodMapping 來宣告子估算器(被呼叫者)的哪些方法會在元估算器(呼叫者)的哪些方法中使用。如您所見,SimplePipelinefit 中使用轉換器的 transformfit 方法,並在 predict 中使用其 transform 方法,這就是您在管道類別的路由結構中所看到的實作。

上述範例與先前範例的另一個不同之處在於使用了 process_routing,它會處理輸入參數,執行必要的驗證,並傳回我們在先前的範例中建立的 routed_params。這減少了開發人員需要在每個元估算器的方法中編寫的樣板程式碼。強烈建議開發人員使用此函數,除非有充分的理由反對使用。

為了測試上述管道,讓我們新增一個範例轉換器。

class ExampleTransformer(TransformerMixin, BaseEstimator):
    def fit(self, X, y, sample_weight=None):
        check_metadata(self, sample_weight=sample_weight)
        return self

    def transform(self, X, groups=None):
        check_metadata(self, groups=groups)
        return X

    def fit_transform(self, X, y, sample_weight=None, groups=None):
        return self.fit(X, y, sample_weight).transform(X, groups)

請注意,在上述範例中,我們實作了 fit_transform,它會使用適當的元數據呼叫 fittransform。只有當 transform 接受元數據時,才需要這樣做,因為 TransformerMixin 中的預設 fit_transform 實作不會將元數據傳遞給 transform

現在我們可以測試我們的管道,看看元數據是否正確傳遞。這個範例使用我們的 SimplePipeline、我們的 ExampleTransformer 和我們的 RouterConsumerClassifier,後者使用我們的 ExampleClassifier

pipe = SimplePipeline(
    transformer=ExampleTransformer()
    # we set transformer's fit to receive sample_weight
    .set_fit_request(sample_weight=True)
    # we set transformer's transform to receive groups
    .set_transform_request(groups=True),
    classifier=RouterConsumerClassifier(
        estimator=ExampleClassifier()
        # we want this sub-estimator to receive sample_weight in fit
        .set_fit_request(sample_weight=True)
        # but not groups in predict
        .set_predict_request(groups=False),
    )
    # and we want the meta-estimator to receive sample_weight as well
    .set_fit_request(sample_weight=True),
)
pipe.fit(X, y, sample_weight=my_weights, groups=my_groups).predict(
    X[:3], groups=my_groups
)
Received sample_weight of length = 100 in ExampleTransformer.
Received groups of length = 100 in ExampleTransformer.
Received sample_weight of length = 100 in RouterConsumerClassifier.
Received sample_weight of length = 100 in ExampleClassifier.
Received groups of length = 100 in ExampleTransformer.
groups is None in ExampleClassifier.

array([1., 1., 1.])

棄用 / 預設值變更#

在本節中,我們將展示如何處理路由器同時成為消費者的情況,特別是當它消耗與其子估計器相同的元數據時,或者當消費者開始消耗舊版本中未消耗的元數據時。在這種情況下,應發出一段時間的警告,讓使用者知道行為已與先前版本不同。

class MetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator):
    def __init__(self, estimator):
        self.estimator = estimator

    def fit(self, X, y, **fit_params):
        routed_params = (self, "fit", **fit_params)
        self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit)

    def get_metadata_routing(self):
        router = (owner=self.__class__.__name__).add(
            estimator=self.estimator,
            method_mapping=().add(caller="fit", callee="fit"),
        )
        return router

如上所述,如果 my_weights 不應作為 sample_weight 傳遞給 MetaRegressor,則這是有效的用法。

reg = MetaRegressor(estimator=LinearRegression().set_fit_request(sample_weight=True))
reg.fit(X, y, sample_weight=my_weights)

現在想像我們進一步開發 MetaRegressor,它現在也消耗 sample_weight

class WeightedMetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator):
    # show warning to remind user to explicitly set the value with
    # `.set_{method}_request(sample_weight={boolean})`
    __metadata_request__fit = {"sample_weight": metadata_routing.WARN}

    def __init__(self, estimator):
        self.estimator = estimator

    def fit(self, X, y, sample_weight=None, **fit_params):
        routed_params = (
            self, "fit", sample_weight=sample_weight, **fit_params
        )
        check_metadata(self, sample_weight=sample_weight)
        self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit)

    def get_metadata_routing(self):
        router = (
            (owner=self.__class__.__name__)
            .add_self_request(self)
            .add(
                estimator=self.estimator,
                method_mapping=().add(caller="fit", callee="fit"),
            )
        )
        return router

上述實作幾乎與 MetaRegressor 相同,並且由於 __metadata_request__fit 中定義的預設請求值,在擬合時會發出警告。

with warnings.catch_warnings(record=True) as record:
    WeightedMetaRegressor(
        estimator=LinearRegression().set_fit_request(sample_weight=False)
    ).fit(X, y, sample_weight=my_weights)
for w in record:
    print(w.message)
Received sample_weight of length = 100 in WeightedMetaRegressor.
Support for sample_weight has recently been added to this class. To maintain backward compatibility, it is ignored now. Using `set_fit_request(sample_weight={True, False})` on this method of the class, you can set the request value to False to silence this warning, or to True to consume and use the metadata.

當估計器消耗之前未消耗的元數據時,可以使用以下模式來警告使用者。

class ExampleRegressor(RegressorMixin, BaseEstimator):
    __metadata_request__fit = {"sample_weight": metadata_routing.WARN}

    def fit(self, X, y, sample_weight=None):
        check_metadata(self, sample_weight=sample_weight)
        return self

    def predict(self, X):
        return np.zeros(shape=(len(X)))


with warnings.catch_warnings(record=True) as record:
    MetaRegressor(estimator=ExampleRegressor()).fit(X, y, sample_weight=my_weights)
for w in record:
    print(w.message)
sample_weight is None in ExampleRegressor.
Support for sample_weight has recently been added to this class. To maintain backward compatibility, it is ignored now. Using `set_fit_request(sample_weight={True, False})` on this method of the class, you can set the request value to False to silence this warning, or to True to consume and use the metadata.

最後,我們禁用元數據路由的配置標誌。

set_config(enable_metadata_routing=False)

第三方開發和 scikit-learn 依賴性#

如上所述,資訊是使用 MetadataRequestMetadataRouter 在類別之間傳遞的。強烈不建議,但如果您嚴格想要擁有一個與 scikit-learn 相容的估計器,而不依賴 scikit-learn 套件,則可以供應與元數據路由相關的工具。如果滿足以下所有條件,則您完全不需要修改程式碼

  • 您的估計器繼承自 BaseEstimator

  • 您的估計器方法(例如 fit)消耗的參數,在方法的簽名中明確定義,而不是 *args*kwargs

  • 您的估計器不會將任何元數據路由到基礎物件,即它不是路由器

腳本總執行時間:(0 分鐘 0.043 秒)

相關範例

使用預先計算的 Gram 矩陣和加權樣本擬合 Elastic Net

使用預先計算的 Gram 矩陣和加權樣本擬合 Elastic Net

scikit-learn 1.4 的發行重點

scikit-learn 1.4 的發行重點

SGD:加權樣本

SGD:加權樣本

scikit-learn 1.6 的發行重點

scikit-learn 1.6 的發行重點

由 Sphinx-Gallery 產生的圖庫