分層群組 K 折 (StratifiedGroupKFold)#

class sklearn.model_selection.StratifiedGroupKFold(n_splits=5, shuffle=False, random_state=None)[原始碼]#

具有不重疊群組的分層 K 折疊迭代器變體。

此交叉驗證物件是 StratifiedKFold 的變體,嘗試返回具有不重疊群組的分層折疊。 折疊是通過保留每個類別的樣本百分比來完成的。

每個群組在所有折疊的測試集中只會出現一次(不同群組的數量必須至少等於折疊的數量)。

GroupKFoldStratifiedGroupKFold 之間的差異在於,前者嘗試建立平衡的折疊,使每個折疊中不同群組的數量大致相同,而 StratifiedGroupKFold 嘗試建立折疊,在分割之間不重疊群組的約束下,盡可能保留每個類別的樣本百分比。

請參閱使用者指南了解更多資訊。

如需視覺化交叉驗證行為和比較常見的 scikit-learn 分割方法,請參閱在 scikit-learn 中視覺化交叉驗證行為

參數:
n_splitsint,預設值=5

折疊數。必須至少為 2。

shufflebool,預設值=False

在分割成批次之前,是否打亂每個類別的樣本。 請注意,每個分割中的樣本不會被洗牌。 此實作只能洗牌具有大致相同 y 分佈的群組,不會執行全域洗牌。

random_stateint 或 RandomState 實例,預設值=None

shuffle 為 True 時,random_state 會影響索引的排序,進而控制每個類別每個折疊的隨機性。 否則,請將 random_state 保留為 None。 傳遞 int 以在多個函數呼叫中產生可重複的輸出。 請參閱詞彙表

另請參閱

StratifiedKFold

考慮類別資訊以建立保留類別分佈的折疊(用於二元或多類別分類任務)。

GroupKFold

具有不重疊群組的 K 折疊迭代器變體。

注意事項

實作旨在

  • 對於微不足道的群組(例如,當每個群組僅包含一個樣本時),盡可能模擬 StratifiedKFold 的行為。

  • 不受類別標籤影響:將 y = ["Happy", "Sad"] 重新標記為 y = [1, 0] 不應變更產生的索引。

  • 盡可能根據樣本進行分層,同時保留不重疊群組的約束。 這表示在某些情況下,當存在少量包含大量樣本的群組時,將無法進行分層,並且其行為將接近 GroupKFold。

範例

>>> import numpy as np
>>> from sklearn.model_selection import StratifiedGroupKFold
>>> X = np.ones((17, 2))
>>> y = np.array([0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0])
>>> groups = np.array([1, 1, 2, 2, 3, 3, 3, 4, 5, 5, 5, 5, 6, 6, 7, 8, 8])
>>> sgkf = StratifiedGroupKFold(n_splits=3)
>>> sgkf.get_n_splits(X, y)
3
>>> print(sgkf)
StratifiedGroupKFold(n_splits=3, random_state=None, shuffle=False)
>>> for i, (train_index, test_index) in enumerate(sgkf.split(X, y, groups)):
...     print(f"Fold {i}:")
...     print(f"  Train: index={train_index}")
...     print(f"         group={groups[train_index]}")
...     print(f"  Test:  index={test_index}")
...     print(f"         group={groups[test_index]}")
Fold 0:
  Train: index=[ 0  1  2  3  7  8  9 10 11 15 16]
         group=[1 1 2 2 4 5 5 5 5 8 8]
  Test:  index=[ 4  5  6 12 13 14]
         group=[3 3 3 6 6 7]
Fold 1:
  Train: index=[ 4  5  6  7  8  9 10 11 12 13 14]
         group=[3 3 3 4 5 5 5 5 6 6 7]
  Test:  index=[ 0  1  2  3 15 16]
         group=[1 1 2 2 8 8]
Fold 2:
  Train: index=[ 0  1  2  3  4  5  6 12 13 14 15 16]
         group=[1 1 2 2 3 3 3 6 6 7 8 8]
  Test:  index=[ 7  8  9 10 11]
         group=[4 5 5 5 5]
get_metadata_routing()[原始碼]#

取得此物件的中繼資料路由。

請查看使用者指南,了解路由機制如何運作。

傳回:
routingMetadataRequest

封裝路由資訊的 MetadataRequest

get_n_splits(X=None, y=None, groups=None)[原始碼]#

傳回交叉驗證器中的分割迭代次數。

參數:
X物件

始終忽略,存在以實現相容性。

y物件

始終忽略,存在以實現相容性。

groups物件

始終忽略,存在以實現相容性。

傳回:
n_splitsint

傳回交叉驗證器中的分割迭代次數。

set_split_request(*, groups: bool | None | str = '$UNCHANGED$') StratifiedGroupKFold[原始碼]#

要求傳遞至 split 方法的中繼資料。

請注意,只有在 enable_metadata_routing=True 時,此方法才相關(請參閱 sklearn.set_config)。 請參閱使用者指南,了解路由機制如何運作。

每個參數的選項是

  • True:會請求中繼資料,如果提供,則會傳遞至 split。 如果未提供中繼資料,則會忽略此要求。

  • False:不會請求中繼資料,且元估計器不會將其傳遞至 split

  • None:不請求元數據,如果使用者提供元數據,則元估計器會引發錯誤。

  • str:元數據應使用此給定的別名,而不是原始名稱傳遞給元估計器。

預設值(sklearn.utils.metadata_routing.UNCHANGED)保留現有的請求。這允許您變更某些參數的請求,而不變更其他參數的請求。

在 1.3 版本中新增。

注意

只有當此估計器作為元估計器的子估計器使用時,此方法才相關,例如在 Pipeline 中使用。否則,它不會產生任何影響。

參數:
groupsstr、True、False 或 None,預設值為 sklearn.utils.metadata_routing.UNCHANGED

splitgroups 參數的元數據路由。

傳回:
self物件

更新後的物件。

split(X, y=None, groups=None)[原始碼]#

產生將資料分割成訓練集和測試集的索引。

參數:
X形狀為 (n_samples, n_features) 的類陣列

訓練資料,其中 n_samples 是樣本數,而 n_features 是特徵數。

y形狀為 (n_samples,) 的類陣列,預設值為 None

用於監督式學習問題的目標變數。

groups形狀為 (n_samples,) 的類陣列,預設值為 None

將資料集分割成訓練集/測試集時使用的樣本群組標籤。

產生:
trainndarray

該分割的訓練集索引。

testndarray

該分割的測試集索引。