分層群組 K 折 (StratifiedGroupKFold)#
- class sklearn.model_selection.StratifiedGroupKFold(n_splits=5, shuffle=False, random_state=None)[原始碼]#
具有不重疊群組的分層 K 折疊迭代器變體。
此交叉驗證物件是 StratifiedKFold 的變體,嘗試返回具有不重疊群組的分層折疊。 折疊是通過保留每個類別的樣本百分比來完成的。
每個群組在所有折疊的測試集中只會出現一次(不同群組的數量必須至少等於折疊的數量)。
GroupKFold
和StratifiedGroupKFold
之間的差異在於,前者嘗試建立平衡的折疊,使每個折疊中不同群組的數量大致相同,而StratifiedGroupKFold
嘗試建立折疊,在分割之間不重疊群組的約束下,盡可能保留每個類別的樣本百分比。請參閱使用者指南了解更多資訊。
如需視覺化交叉驗證行為和比較常見的 scikit-learn 分割方法,請參閱在 scikit-learn 中視覺化交叉驗證行為
- 參數:
- n_splitsint,預設值=5
折疊數。必須至少為 2。
- shufflebool,預設值=False
在分割成批次之前,是否打亂每個類別的樣本。 請注意,每個分割中的樣本不會被洗牌。 此實作只能洗牌具有大致相同 y 分佈的群組,不會執行全域洗牌。
- random_stateint 或 RandomState 實例,預設值=None
當
shuffle
為 True 時,random_state
會影響索引的排序,進而控制每個類別每個折疊的隨機性。 否則,請將random_state
保留為None
。 傳遞 int 以在多個函數呼叫中產生可重複的輸出。 請參閱詞彙表。
另請參閱
StratifiedKFold
考慮類別資訊以建立保留類別分佈的折疊(用於二元或多類別分類任務)。
GroupKFold
具有不重疊群組的 K 折疊迭代器變體。
注意事項
實作旨在
對於微不足道的群組(例如,當每個群組僅包含一個樣本時),盡可能模擬 StratifiedKFold 的行為。
不受類別標籤影響:將
y = ["Happy", "Sad"]
重新標記為y = [1, 0]
不應變更產生的索引。盡可能根據樣本進行分層,同時保留不重疊群組的約束。 這表示在某些情況下,當存在少量包含大量樣本的群組時,將無法進行分層,並且其行為將接近 GroupKFold。
範例
>>> import numpy as np >>> from sklearn.model_selection import StratifiedGroupKFold >>> X = np.ones((17, 2)) >>> y = np.array([0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]) >>> groups = np.array([1, 1, 2, 2, 3, 3, 3, 4, 5, 5, 5, 5, 6, 6, 7, 8, 8]) >>> sgkf = StratifiedGroupKFold(n_splits=3) >>> sgkf.get_n_splits(X, y) 3 >>> print(sgkf) StratifiedGroupKFold(n_splits=3, random_state=None, shuffle=False) >>> for i, (train_index, test_index) in enumerate(sgkf.split(X, y, groups)): ... print(f"Fold {i}:") ... print(f" Train: index={train_index}") ... print(f" group={groups[train_index]}") ... print(f" Test: index={test_index}") ... print(f" group={groups[test_index]}") Fold 0: Train: index=[ 0 1 2 3 7 8 9 10 11 15 16] group=[1 1 2 2 4 5 5 5 5 8 8] Test: index=[ 4 5 6 12 13 14] group=[3 3 3 6 6 7] Fold 1: Train: index=[ 4 5 6 7 8 9 10 11 12 13 14] group=[3 3 3 4 5 5 5 5 6 6 7] Test: index=[ 0 1 2 3 15 16] group=[1 1 2 2 8 8] Fold 2: Train: index=[ 0 1 2 3 4 5 6 12 13 14 15 16] group=[1 1 2 2 3 3 3 6 6 7 8 8] Test: index=[ 7 8 9 10 11] group=[4 5 5 5 5]
- get_metadata_routing()[原始碼]#
取得此物件的中繼資料路由。
請查看使用者指南,了解路由機制如何運作。
- 傳回:
- routingMetadataRequest
封裝路由資訊的
MetadataRequest
。
- get_n_splits(X=None, y=None, groups=None)[原始碼]#
傳回交叉驗證器中的分割迭代次數。
- 參數:
- X物件
始終忽略,存在以實現相容性。
- y物件
始終忽略,存在以實現相容性。
- groups物件
始終忽略,存在以實現相容性。
- 傳回:
- n_splitsint
傳回交叉驗證器中的分割迭代次數。
- set_split_request(*, groups: bool | None | str = '$UNCHANGED$') StratifiedGroupKFold [原始碼]#
要求傳遞至
split
方法的中繼資料。請注意,只有在
enable_metadata_routing=True
時,此方法才相關(請參閱sklearn.set_config
)。 請參閱使用者指南,了解路由機制如何運作。每個參數的選項是
True
:會請求中繼資料,如果提供,則會傳遞至split
。 如果未提供中繼資料,則會忽略此要求。False
:不會請求中繼資料,且元估計器不會將其傳遞至split
。None
:不請求元數據,如果使用者提供元數據,則元估計器會引發錯誤。str
:元數據應使用此給定的別名,而不是原始名稱傳遞給元估計器。
預設值(
sklearn.utils.metadata_routing.UNCHANGED
)保留現有的請求。這允許您變更某些參數的請求,而不變更其他參數的請求。在 1.3 版本中新增。
注意
只有當此估計器作為元估計器的子估計器使用時,此方法才相關,例如在
Pipeline
中使用。否則,它不會產生任何影響。- 參數:
- groupsstr、True、False 或 None,預設值為 sklearn.utils.metadata_routing.UNCHANGED
split
中groups
參數的元數據路由。
- 傳回:
- self物件
更新後的物件。