使用繪圖 API 開發#
Scikit-learn 定義了一個簡單的 API,用於創建機器學習的可視化圖表。這個 API 的關鍵功能是運行一次計算,並在事後具有調整可視化圖表的彈性。本節適用於希望開發或維護繪圖工具的開發人員。關於使用方式,使用者應參考使用者指南。
繪圖 API 概覽#
此邏輯封裝在一個顯示物件中,其中儲存計算的資料,並且在 plot
方法中完成繪圖。顯示物件的 __init__
方法僅包含創建可視化圖表所需的資料。plot
方法接收僅與可視化圖表相關的參數,例如 matplotlib 軸。plot
方法將 matplotlib artists 儲存為屬性,允許通過顯示物件進行樣式調整。Display
類別應定義一個或兩個類別方法:from_estimator
和 from_predictions
。這些方法允許從估算器和一些資料或從真實值和預測值建立 Display
物件。在這些類別方法使用計算的值建立顯示物件後,則呼叫顯示的 plot 方法。請注意,plot
方法定義了與 matplotlib 相關的屬性,例如線條 artist。這允許在呼叫 plot
方法後進行自訂。
例如,RocCurveDisplay
定義了以下方法和屬性
class RocCurveDisplay:
def __init__(self, fpr, tpr, roc_auc, estimator_name):
...
self.fpr = fpr
self.tpr = tpr
self.roc_auc = roc_auc
self.estimator_name = estimator_name
@classmethod
def from_estimator(cls, estimator, X, y):
# get the predictions
y_pred = estimator.predict_proba(X)[:, 1]
return cls.from_predictions(y, y_pred, estimator.__class__.__name__)
@classmethod
def from_predictions(cls, y, y_pred, estimator_name):
# do ROC computation from y and y_pred
fpr, tpr, roc_auc = ...
viz = RocCurveDisplay(fpr, tpr, roc_auc, estimator_name)
return viz.plot()
def plot(self, ax=None, name=None, **kwargs):
...
self.line_ = ...
self.ax_ = ax
self.figure_ = ax.figure_
請在使用可視化 API 的 ROC 曲線和使用者指南中閱讀更多資訊。
使用多個軸繪圖#
一些繪圖工具(例如from_estimator
和PartialDependenceDisplay
)支援在多個軸上繪圖。支援兩種不同的情境
1. 如果傳入軸的清單,plot
將檢查軸的數量是否與其預期的軸數一致,然後在這些軸上繪圖。2. 如果傳入單個軸,則該軸定義要放置多個軸的空間。在這種情況下,我們建議使用 matplotlib 的 ~matplotlib.gridspec.GridSpecFromSubplotSpec
來分割空間
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpecFromSubplotSpec
fig, ax = plt.subplots()
gs = GridSpecFromSubplotSpec(2, 2, subplot_spec=ax.get_subplotspec())
ax_top_left = fig.add_subplot(gs[0, 0])
ax_top_right = fig.add_subplot(gs[0, 1])
ax_bottom = fig.add_subplot(gs[1, :])
預設情況下,plot
中的 ax
關鍵字為 None
。在這種情況下,將建立單個軸,並使用 gridspec api 來建立繪圖區域。
例如,請參閱from_estimator
,它使用此 API 繪製多條線條和輪廓。定義邊界框的軸儲存在 bounding_ax_
屬性中。建立的個別軸儲存在 axes_
ndarray 中,對應於網格上的軸位置。未使用的位置會設定為 None
。此外,matplotlib Artists 儲存在 lines_
和 contours_
中,其中索引鍵是網格上的位置。當傳入軸的清單時,axes_
、lines_
和 contours_
是一個 1d ndarray,對應於傳入的軸清單。