GroupKFold#
- class sklearn.model_selection.GroupKFold(n_splits=5, *, shuffle=False, random_state=None)[原始碼]#
具有不重疊群組的 K 折疊迭代器變體。
每個群組在所有折疊的測試集中只會出現一次(不同群組的數量必須至少等於折疊的數量)。
當
shuffle
為 True 時,折疊會大致平衡,因為每個測試折疊中的樣本數量大致相同。請在 使用者指南 中閱讀更多資訊。
關於交叉驗證行為的可視化以及常見 scikit-learn 分割方法之間的比較,請參考 在 scikit-learn 中可視化交叉驗證行為
- 參數:
- n_splitsint, default=5
折疊的數量。必須至少為 2。
版本變更自 0.22:
n_splits
的預設值從 3 變更為 5。- shufflebool, default=False
是否在分割成批次之前打亂群組。請注意,每個分割中的樣本不會被打亂。
新增於版本 1.6。
- random_stateint, RandomState 實例或 None, default=None
當
shuffle
為 True 時,random_state
會影響索引的順序,從而控制每個折疊的隨機性。否則,此參數無效。傳遞一個 int 以便在多次函數呼叫中獲得可重複的輸出。請參閱 詞彙表。新增於版本 1.6。
另請參閱
LeaveOneGroupOut
用於根據資料集的明確領域特定分層來分割資料。
StratifiedKFold
考慮類別資訊以避免建立具有不平衡類別比例的折疊(用於二元或多類別分類任務)。
註解
群組在整個折疊中以任意順序出現。
範例
>>> import numpy as np >>> from sklearn.model_selection import GroupKFold >>> X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]]) >>> y = np.array([1, 2, 3, 4, 5, 6]) >>> groups = np.array([0, 0, 2, 2, 3, 3]) >>> group_kfold = GroupKFold(n_splits=2) >>> group_kfold.get_n_splits(X, y, groups) 2 >>> print(group_kfold) GroupKFold(n_splits=2, random_state=None, shuffle=False) >>> for i, (train_index, test_index) in enumerate(group_kfold.split(X, y, groups)): ... print(f"Fold {i}:") ... print(f" Train: index={train_index}, group={groups[train_index]}") ... print(f" Test: index={test_index}, group={groups[test_index]}") Fold 0: Train: index=[2 3], group=[2 2] Test: index=[0 1 4 5], group=[0 0 3 3] Fold 1: Train: index=[0 1 4 5], group=[0 0 3 3] Test: index=[2 3], group=[2 2]
- get_metadata_routing()[原始碼]#
取得此物件的中繼資料路由。
請檢查 使用者指南,了解路由機制如何運作。
- 回傳值:
- routingMetadataRequest
封裝路由資訊的
MetadataRequest
。
- get_n_splits(X=None, y=None, groups=None)[原始碼]#
傳回交叉驗證器中的分割迭代次數。
- 參數:
- Xobject
總是忽略,為相容性而存在。
- yobject
總是忽略,為相容性而存在。
- groupsobject
總是忽略,為相容性而存在。
- 回傳值:
- n_splitsint
傳回交叉驗證器中的分割迭代次數。
- set_split_request(*, groups: bool | None | str = '$UNCHANGED$') GroupKFold [原始碼]#
請求傳遞到
split
方法的中繼資料。請注意,只有在
enable_metadata_routing=True
時,此方法才相關(請參閱sklearn.set_config
)。請參閱 使用者指南,了解路由機制如何運作。每個參數的選項如下
True
:請求中繼資料,並在提供時傳遞到split
。如果未提供中繼資料,則會忽略請求。False
:不請求中繼資料,而且 meta-estimator 不會將其傳遞到split
。None
:不請求中繼資料,如果使用者提供中繼資料,meta-estimator 將會引發錯誤。str
:元數據應使用此別名傳遞給元估算器,而不是原始名稱。
預設值 (
sklearn.utils.metadata_routing.UNCHANGED
) 會保留現有的請求。這允許您更改某些參數的請求,而不更改其他參數的請求。在 1.3 版本中新增。
注意
僅當此估算器用作元估算器的子估算器時,此方法才相關,例如在
Pipeline
內部使用。否則,它沒有任何作用。- 參數:
- groupsstr、True、False 或 None,預設值=sklearn.utils.metadata_routing.UNCHANGED
split
中groups
參數的元數據路由。
- 回傳值:
- selfobject
更新後的物件。