使用交叉驗證的接收者操作特徵 (ROC)#

此範例介紹如何使用交叉驗證來估計和視覺化接收者操作特徵 (ROC) 指標的變異數。

ROC 曲線通常在 Y 軸上呈現真陽性率 (TPR),而在 X 軸上呈現假陽性率 (FPR)。這表示圖表的左上角是「理想」點 - FPR 為零,TPR 為一。這不是很實際,但這確實表示曲線下面積 (AUC) 越大通常越好。ROC 曲線的「陡峭度」也很重要,因為最大化 TPR 同時最小化 FPR 是理想的。

此範例顯示由 K 折交叉驗證建立的不同資料集的 ROC 回應。取所有這些曲線,可以計算平均 AUC,並查看訓練集分割為不同子集時曲線的變異數。這大略顯示分類器輸出如何受訓練資料變更的影響,以及 K 折交叉驗證產生的不同分割彼此之間的差異程度。

注意

請參閱 多類別接收者操作特徵 (ROC),以補充本範例,說明將指標概括化為多類別分類器的平均策略。

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

載入並準備資料#

我們匯入鳶尾花植物資料集,其中包含 3 個類別,每個類別對應於一種鳶尾花植物。一個類別與其他 2 個類別線性可分離;後者彼此不是線性可分離的。

在以下範例中,我們透過捨棄「virginica」類別 (class_id=2) 來將資料集二值化。這表示「versicolor」類別 (class_id=1) 被視為正類別,而「setosa」被視為負類別 (class_id=0)。

import numpy as np

from sklearn.datasets import load_iris

iris = load_iris()
target_names = iris.target_names
X, y = iris.data, iris.target
X, y = X[y != 2], y[y != 2]
n_samples, n_features = X.shape

我們還加入雜訊特徵,使問題更難。

random_state = np.random.RandomState(0)
X = np.concatenate([X, random_state.randn(n_samples, 200 * n_features)], axis=1)

分類和 ROC 分析#

這裡我們使用交叉驗證執行SVC分類器,並逐折繪製 ROC 曲線。請注意,定義機率水準的基準 (虛線 ROC 曲線) 是一個始終預測最頻繁類別的分類器。

import matplotlib.pyplot as plt

from sklearn import svm
from sklearn.metrics import RocCurveDisplay, auc
from sklearn.model_selection import StratifiedKFold

n_splits = 6
cv = StratifiedKFold(n_splits=n_splits)
classifier = svm.SVC(kernel="linear", probability=True, random_state=random_state)

tprs = []
aucs = []
mean_fpr = np.linspace(0, 1, 100)

fig, ax = plt.subplots(figsize=(6, 6))
for fold, (train, test) in enumerate(cv.split(X, y)):
    classifier.fit(X[train], y[train])
    viz = RocCurveDisplay.from_estimator(
        classifier,
        X[test],
        y[test],
        name=f"ROC fold {fold}",
        alpha=0.3,
        lw=1,
        ax=ax,
        plot_chance_level=(fold == n_splits - 1),
    )
    interp_tpr = np.interp(mean_fpr, viz.fpr, viz.tpr)
    interp_tpr[0] = 0.0
    tprs.append(interp_tpr)
    aucs.append(viz.roc_auc)

mean_tpr = np.mean(tprs, axis=0)
mean_tpr[-1] = 1.0
mean_auc = auc(mean_fpr, mean_tpr)
std_auc = np.std(aucs)
ax.plot(
    mean_fpr,
    mean_tpr,
    color="b",
    label=r"Mean ROC (AUC = %0.2f $\pm$ %0.2f)" % (mean_auc, std_auc),
    lw=2,
    alpha=0.8,
)

std_tpr = np.std(tprs, axis=0)
tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
ax.fill_between(
    mean_fpr,
    tprs_lower,
    tprs_upper,
    color="grey",
    alpha=0.2,
    label=r"$\pm$ 1 std. dev.",
)

ax.set(
    xlabel="False Positive Rate",
    ylabel="True Positive Rate",
    title=f"Mean ROC curve with variability\n(Positive label '{target_names[1]}')",
)
ax.legend(loc="lower right")
plt.show()
Mean ROC curve with variability (Positive label 'versicolor')

指令碼的總執行時間: (0 分鐘 0.194 秒)

相關範例

多類別接收者操作特徵 (ROC)

多類別接收者操作特徵 (ROC)

使用視覺化 API 的 ROC 曲線

使用視覺化 API 的 ROC 曲線

檢測錯誤權衡 (DET) 曲線

檢測錯誤權衡 (DET) 曲線

使用顯示物件進行視覺化

使用顯示物件進行視覺化

由 Sphinx-Gallery 產生的圖庫