注意
前往結尾以下載完整的範例程式碼。或透過 JupyterLite 或 Binder 在您的瀏覽器中執行此範例
使用交叉驗證的接收者操作特徵 (ROC)#
此範例介紹如何使用交叉驗證來估計和視覺化接收者操作特徵 (ROC) 指標的變異數。
ROC 曲線通常在 Y 軸上呈現真陽性率 (TPR),而在 X 軸上呈現假陽性率 (FPR)。這表示圖表的左上角是「理想」點 - FPR 為零,TPR 為一。這不是很實際,但這確實表示曲線下面積 (AUC) 越大通常越好。ROC 曲線的「陡峭度」也很重要,因為最大化 TPR 同時最小化 FPR 是理想的。
此範例顯示由 K 折交叉驗證建立的不同資料集的 ROC 回應。取所有這些曲線,可以計算平均 AUC,並查看訓練集分割為不同子集時曲線的變異數。這大略顯示分類器輸出如何受訓練資料變更的影響,以及 K 折交叉驗證產生的不同分割彼此之間的差異程度。
注意
請參閱 多類別接收者操作特徵 (ROC),以補充本範例,說明將指標概括化為多類別分類器的平均策略。
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
載入並準備資料#
我們匯入鳶尾花植物資料集,其中包含 3 個類別,每個類別對應於一種鳶尾花植物。一個類別與其他 2 個類別線性可分離;後者彼此不是線性可分離的。
在以下範例中,我們透過捨棄「virginica」類別 (class_id=2
) 來將資料集二值化。這表示「versicolor」類別 (class_id=1
) 被視為正類別,而「setosa」被視為負類別 (class_id=0
)。
我們還加入雜訊特徵,使問題更難。
random_state = np.random.RandomState(0)
X = np.concatenate([X, random_state.randn(n_samples, 200 * n_features)], axis=1)
分類和 ROC 分析#
這裡我們使用交叉驗證執行SVC
分類器,並逐折繪製 ROC 曲線。請注意,定義機率水準的基準 (虛線 ROC 曲線) 是一個始終預測最頻繁類別的分類器。
import matplotlib.pyplot as plt
from sklearn import svm
from sklearn.metrics import RocCurveDisplay, auc
from sklearn.model_selection import StratifiedKFold
n_splits = 6
cv = StratifiedKFold(n_splits=n_splits)
classifier = svm.SVC(kernel="linear", probability=True, random_state=random_state)
tprs = []
aucs = []
mean_fpr = np.linspace(0, 1, 100)
fig, ax = plt.subplots(figsize=(6, 6))
for fold, (train, test) in enumerate(cv.split(X, y)):
classifier.fit(X[train], y[train])
viz = RocCurveDisplay.from_estimator(
classifier,
X[test],
y[test],
name=f"ROC fold {fold}",
alpha=0.3,
lw=1,
ax=ax,
plot_chance_level=(fold == n_splits - 1),
)
interp_tpr = np.interp(mean_fpr, viz.fpr, viz.tpr)
interp_tpr[0] = 0.0
tprs.append(interp_tpr)
aucs.append(viz.roc_auc)
mean_tpr = np.mean(tprs, axis=0)
mean_tpr[-1] = 1.0
mean_auc = auc(mean_fpr, mean_tpr)
std_auc = np.std(aucs)
ax.plot(
mean_fpr,
mean_tpr,
color="b",
label=r"Mean ROC (AUC = %0.2f $\pm$ %0.2f)" % (mean_auc, std_auc),
lw=2,
alpha=0.8,
)
std_tpr = np.std(tprs, axis=0)
tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
ax.fill_between(
mean_fpr,
tprs_lower,
tprs_upper,
color="grey",
alpha=0.2,
label=r"$\pm$ 1 std. dev.",
)
ax.set(
xlabel="False Positive Rate",
ylabel="True Positive Rate",
title=f"Mean ROC curve with variability\n(Positive label '{target_names[1]}')",
)
ax.legend(loc="lower right")
plt.show()

指令碼的總執行時間: (0 分鐘 0.194 秒)
相關範例