取得 20 個新聞群組 (fetch_20newsgroups)#

sklearn.datasets.fetch_20newsgroups(*, data_home=None, subset='train', categories=None, shuffle=True, random_state=42, remove=(), download_if_missing=True, return_X_y=False, n_retries=3, delay=1.0)[來源]#

從 20 個新聞群組數據集(分類)載入檔案名稱和數據。

如有需要,請下載。

類別

20

總樣本數

18846

維度

1

特徵

文字

請在使用者指南中閱讀更多內容。

參數:
data_homestr 或 path-like,預設值為 None

指定數據集的下載和快取資料夾。如果為 None,則所有 scikit-learn 數據都將儲存在 ‘~/scikit_learn_data’ 子資料夾中。

subset{‘train’, ‘test’, ‘all’},預設值為 ‘train’

選擇要載入的數據集: ‘train’ 為訓練集,’test’ 為測試集,’all’ 為兩者,並以隨機順序排列。

categoriesarray-like, dtype=str, 預設值為 None

如果為 None (預設值),則載入所有類別。如果不是 None,則載入類別名稱列表 (其他類別將被忽略)。

shufflebool,預設值為 True

是否要對數據進行洗牌:對於假設樣本是獨立且均勻分佈(i.i.d.)的模型 (例如隨機梯度下降) 來說可能很重要。

random_stateint,RandomState 實例或 None,預設值為 42

決定數據集洗牌的隨機數生成。傳遞一個整數以在多個函數呼叫中產生可重複的輸出。請參閱詞彙表

removetuple,預設值為 ()

可以包含 (‘headers’, ‘footers’, ‘quotes’) 的任何子集。這些是將會從新聞群組貼文中偵測並移除的文字種類,以防止分類器在元數據上過度擬合。

‘headers’ 移除新聞群組標頭,‘footers’ 移除看起來像簽名的貼文末尾的區塊,而 ‘quotes’ 移除看起來是引用其他貼文的行。

‘headers’ 遵循精確的標準;其他篩選器並非總是正確。

download_if_missingbool,預設值為 True

如果為 False,則當數據在本機不可用時,將引發 OSError,而不是嘗試從來源網站下載數據。

return_X_ybool,預設值為 False

如果為 True,則返回 (data.data, data.target) 而不是 Bunch 物件。

在 0.22 版本中新增。

n_retriesint,預設值為 3

當遇到 HTTP 錯誤時的重試次數。

在 1.5 版本中新增。

delayfloat,預設值為 1.0

重試之間的秒數。

在 1.5 版本中新增。

返回:
bunchBunch

類似字典的物件,具有以下屬性。

data形狀為 (n_samples,) 的列表

要學習的數據列表。

target: 形狀為 (n_samples,) 的 ndarray

目標標籤。

filenames: 形狀為 (n_samples,) 的列表

數據位置的路徑。

DESCR: str

數據集的完整描述。

target_names: 形狀為 (n_classes,) 的列表

目標類別的名稱。

(data, target)如果 return_X_y=True,則為元組

一個包含兩個 ndarray 的元組。第一個包含形狀為 (n_samples, n_classes) 的二維陣列,其中每行代表一個樣本,每列代表一個特徵。第二個形狀為 (n_samples,) 的陣列包含目標樣本。

在 0.22 版本中新增。

範例

>>> from sklearn.datasets import fetch_20newsgroups
>>> cats = ['alt.atheism', 'sci.space']
>>> newsgroups_train = fetch_20newsgroups(subset='train', categories=cats)
>>> list(newsgroups_train.target_names)
['alt.atheism', 'sci.space']
>>> newsgroups_train.filenames.shape
(1073,)
>>> newsgroups_train.target.shape
(1073,)
>>> newsgroups_train.target[:10]
array([0, 1, 1, 1, 0, 1, 1, 0, 0, 0])