precision_recall_curve#

sklearn.metrics.precision_recall_curve(y_true, y_score=None, *, pos_label=None, sample_weight=None, drop_intermediate=False, probas_pred='deprecated')[原始碼]#

計算不同機率閾值下的精確率-召回率配對。

注意:此實作僅限於二元分類任務。

精確率是 tp / (tp + fp) 的比率,其中 tp 是真正例的數量,而 fp 是偽正例的數量。從直觀上來說,精確率是分類器不將負樣本標記為正樣本的能力。

召回率是 tp / (tp + fn) 的比率,其中 tp 是真正例的數量,而 fn 是偽負例的數量。從直觀上來說,召回率是分類器找到所有正樣本的能力。

最後一個精確率和召回率的值分別為 1. 和 0.,並且沒有對應的閾值。這確保了圖形從 y 軸開始。

第一個精確率和召回率的值分別為 precision=類別平衡和 recall=1.0,這對應於始終預測正類別的分類器。

使用者指南中閱讀更多資訊。

參數:
y_true形狀為 (n_samples,) 的類陣列

真實二元標籤。如果標籤不是 {-1, 1} 或 {0, 1},則應明確給出 pos_label。

y_score形狀為 (n_samples,) 的類陣列

目標分數,可以是正類別的機率估計值,也可以是不加閾值的決策衡量標準(由某些分類器的 decision_function 返回)。

pos_labelint、float、bool 或 str,預設為 None

正類別的標籤。當 pos_label=None 時,如果 y_true 在 {-1, 1} 或 {0, 1} 中,則 pos_label 設定為 1,否則將引發錯誤。

sample_weight形狀為 (n_samples,) 的類陣列,預設為 None

樣本權重。

drop_intermediatebool,預設為 False

是否捨棄一些次優的閾值,這些閾值不會出現在繪製的精確率-召回率曲線上。這對於建立更輕量的精確率-召回率曲線很有用。

在 1.3 版本中新增。

probas_pred形狀為 (n_samples,) 的類陣列

目標分數,可以是正類別的機率估計值,也可以是不加閾值的決策衡量標準(由某些分類器的 decision_function 返回)。

自 1.5 版本起已棄用:probas_pred 已棄用,將在 1.7 版本中移除。請改用 y_score

返回:
precision形狀為 (n_thresholds + 1,) 的 ndarray

精確率值,使得元素 i 是分數 >= thresholds[i] 的預測的精確率,最後一個元素是 1。

recall形狀為 (n_thresholds + 1,) 的 ndarray

遞減的召回率值,使得元素 i 是分數 >= thresholds[i] 的預測的召回率,最後一個元素是 0。

thresholds形狀為 (n_thresholds,) 的 ndarray

用於計算精確率和召回率的決策函數上的遞增閾值,其中 n_thresholds = len(np.unique(probas_pred))

另請參閱

PrecisionRecallDisplay.from_estimator

繪製給定二元分類器的精確率-召回率曲線。

PrecisionRecallDisplay.from_predictions

使用二元分類器的預測繪製精確率-召回率曲線。

average_precision_score

average_precision_score

det_curve

從預測分數計算平均精確率。

roc_curve

roc_curve

範例

>>> import numpy as np
>>> from sklearn.metrics import precision_recall_curve
>>> y_true = np.array([0, 0, 1, 1])
>>> y_scores = np.array([0.1, 0.4, 0.35, 0.8])
>>> precision, recall, thresholds = precision_recall_curve(
...     y_true, y_scores)
>>> precision
array([0.5       , 0.66666667, 0.5       , 1.        , 1.        ])
>>> recall
array([1. , 1. , 0.5, 0.5, 0. ])
>>> thresholds
array([0.1 , 0.35, 0.4 , 0.8 ])
使用顯示物件進行視覺化