GroupKFold#

class sklearn.model_selection.GroupKFold(n_splits=5, *, shuffle=False, random_state=None)[原始碼]#

具有不重疊群組的 K 折疊迭代器變體。

每個群組在所有折疊的測試集中只會出現一次(不同群組的數量必須至少等於折疊的數量)。

shuffle 為 True 時,折疊會大致平衡,因為每個測試折疊中的樣本數量大致相同。

請在 使用者指南 中閱讀更多資訊。

關於交叉驗證行為的可視化以及常見 scikit-learn 分割方法之間的比較,請參考 在 scikit-learn 中可視化交叉驗證行為

參數:
n_splitsint, default=5

折疊的數量。必須至少為 2。

版本變更自 0.22: n_splits 的預設值從 3 變更為 5。

shufflebool, default=False

是否在分割成批次之前打亂群組。請注意,每個分割中的樣本不會被打亂。

新增於版本 1.6。

random_stateint, RandomState 實例或 None, default=None

shuffle 為 True 時,random_state 會影響索引的順序,從而控制每個折疊的隨機性。否則,此參數無效。傳遞一個 int 以便在多次函數呼叫中獲得可重複的輸出。請參閱 詞彙表

新增於版本 1.6。

另請參閱

LeaveOneGroupOut

用於根據資料集的明確領域特定分層來分割資料。

StratifiedKFold

考慮類別資訊以避免建立具有不平衡類別比例的折疊(用於二元或多類別分類任務)。

註解

群組在整個折疊中以任意順序出現。

範例

>>> import numpy as np
>>> from sklearn.model_selection import GroupKFold
>>> X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])
>>> y = np.array([1, 2, 3, 4, 5, 6])
>>> groups = np.array([0, 0, 2, 2, 3, 3])
>>> group_kfold = GroupKFold(n_splits=2)
>>> group_kfold.get_n_splits(X, y, groups)
2
>>> print(group_kfold)
GroupKFold(n_splits=2, random_state=None, shuffle=False)
>>> for i, (train_index, test_index) in enumerate(group_kfold.split(X, y, groups)):
...     print(f"Fold {i}:")
...     print(f"  Train: index={train_index}, group={groups[train_index]}")
...     print(f"  Test:  index={test_index}, group={groups[test_index]}")
Fold 0:
  Train: index=[2 3], group=[2 2]
  Test:  index=[0 1 4 5], group=[0 0 3 3]
Fold 1:
  Train: index=[0 1 4 5], group=[0 0 3 3]
  Test:  index=[2 3], group=[2 2]
get_metadata_routing()[原始碼]#

取得此物件的中繼資料路由。

請檢查 使用者指南,了解路由機制如何運作。

回傳值:
routingMetadataRequest

封裝路由資訊的 MetadataRequest

get_n_splits(X=None, y=None, groups=None)[原始碼]#

傳回交叉驗證器中的分割迭代次數。

參數:
Xobject

總是忽略,為相容性而存在。

yobject

總是忽略,為相容性而存在。

groupsobject

總是忽略,為相容性而存在。

回傳值:
n_splitsint

傳回交叉驗證器中的分割迭代次數。

set_split_request(*, groups: bool | None | str = '$UNCHANGED$') GroupKFold[原始碼]#

請求傳遞到 split 方法的中繼資料。

請注意,只有在 enable_metadata_routing=True 時,此方法才相關(請參閱 sklearn.set_config)。請參閱 使用者指南,了解路由機制如何運作。

每個參數的選項如下

  • True:請求中繼資料,並在提供時傳遞到 split。如果未提供中繼資料,則會忽略請求。

  • False:不請求中繼資料,而且 meta-estimator 不會將其傳遞到 split

  • None:不請求中繼資料,如果使用者提供中繼資料,meta-estimator 將會引發錯誤。

  • str:元數據應使用此別名傳遞給元估算器,而不是原始名稱。

預設值 (sklearn.utils.metadata_routing.UNCHANGED) 會保留現有的請求。這允許您更改某些參數的請求,而不更改其他參數的請求。

在 1.3 版本中新增。

注意

僅當此估算器用作元估算器的子估算器時,此方法才相關,例如在 Pipeline 內部使用。否則,它沒有任何作用。

參數:
groupsstr、True、False 或 None,預設值=sklearn.utils.metadata_routing.UNCHANGED

splitgroups 參數的元數據路由。

回傳值:
selfobject

更新後的物件。

split(X, y=None, groups=None)[原始碼]#

產生將資料分割成訓練集和測試集的索引。

參數:
X形狀為 (n_samples, n_features) 的類陣列

訓練資料,其中 n_samples 是樣本數,n_features 是特徵數。

y形狀為 (n_samples,) 的類陣列,預設值=None

監督式學習問題的目標變數。

groups形狀為 (n_samples,) 的類陣列

在將資料集分割成訓練/測試集時使用的樣本組標籤。

產生值:
trainndarray

該分割的訓練集索引。

testndarray

該分割的測試集索引。