使用顯示物件進行視覺化#

在此範例中,我們將直接從它們各自的指標建構顯示物件,ConfusionMatrixDisplayRocCurveDisplayPrecisionRecallDisplay。當模型的預測已計算或計算成本高昂時,這是一種替代使用其對應繪圖函數的方法。請注意,這是進階用法,一般而言,我們建議使用其各自的繪圖函數。

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

載入資料並訓練模型#

在此範例中,我們從 OpenML 載入一個輸血服務中心數據集。這是一個二元分類問題,其中目標是某人是否捐血。然後將資料分成訓練和測試數據集,並使用訓練數據集擬合邏輯回歸。

from sklearn.datasets import fetch_openml
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

X, y = fetch_openml(data_id=1464, return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y)

clf = make_pipeline(StandardScaler(), LogisticRegression(random_state=0))
clf.fit(X_train, y_train)
Pipeline(steps=[('standardscaler', StandardScaler()),
                ('logisticregression', LogisticRegression(random_state=0))])
在 Jupyter 環境中,請重新執行此儲存格以顯示 HTML 表示法或信任筆記本。
在 GitHub 上,HTML 表示法無法呈現,請嘗試使用 nbviewer.org 載入此頁面。


建立 ConfusionMatrixDisplay#

使用擬合模型,我們計算模型在測試數據集上的預測。這些預測用於計算混淆矩陣,該矩陣使用 ConfusionMatrixDisplay 繪製。

from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix

y_pred = clf.predict(X_test)
cm = confusion_matrix(y_test, y_pred)

cm_display = ConfusionMatrixDisplay(cm).plot()
plot display object visualization

建立 RocCurveDisplay#

ROC 曲線需要估計器的機率或非閾值決策值。由於邏輯回歸提供決策函數,我們將使用它來繪製 ROC 曲線

from sklearn.metrics import RocCurveDisplay, roc_curve

y_score = clf.decision_function(X_test)

fpr, tpr, _ = roc_curve(y_test, y_score, pos_label=clf.classes_[1])
roc_display = RocCurveDisplay(fpr=fpr, tpr=tpr).plot()
plot display object visualization
/home/circleci/project/sklearn/metrics/_plot/roc_curve.py:189: UserWarning:

No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.

建立 PrecisionRecallDisplay#

同樣地,可以使用先前章節的 y_score 繪製精確率-召回率曲線。

from sklearn.metrics import PrecisionRecallDisplay, precision_recall_curve

prec, recall, _ = precision_recall_curve(y_test, y_score, pos_label=clf.classes_[1])
pr_display = PrecisionRecallDisplay(precision=prec, recall=recall).plot()
plot display object visualization

將顯示物件合併到單一繪圖中#

顯示物件會儲存作為引數傳遞的計算值。這允許使用 matplotlib 的 API 輕鬆組合視覺化。在以下範例中,我們將顯示物件並排放在一行中。

import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8))

roc_display.plot(ax=ax1)
pr_display.plot(ax=ax2)
plt.show()
plot display object visualization
/home/circleci/project/sklearn/metrics/_plot/roc_curve.py:189: UserWarning:

No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.

腳本總執行時間: (0 分鐘 0.368 秒)

相關範例

精確率-召回率

精確率-召回率

具有視覺化 API 的 ROC 曲線

具有視覺化 API 的 ROC 曲線

多類別接收者操作特徵 (ROC)

多類別接收者操作特徵 (ROC)

事後調整成本敏感學習的決策閾值

事後調整成本敏感學習的決策閾值

由 Sphinx-Gallery 產生的圖庫