使用繪圖 API 開發#

Scikit-learn 定義了一個簡單的 API,用於創建機器學習的可視化圖表。這個 API 的關鍵功能是運行一次計算,並在事後具有調整可視化圖表的彈性。本節適用於希望開發或維護繪圖工具的開發人員。關於使用方式,使用者應參考使用者指南

繪圖 API 概覽#

此邏輯封裝在一個顯示物件中,其中儲存計算的資料,並且在 plot 方法中完成繪圖。顯示物件的 __init__ 方法僅包含創建可視化圖表所需的資料。plot 方法接收僅與可視化圖表相關的參數,例如 matplotlib 軸。plot 方法將 matplotlib artists 儲存為屬性,允許通過顯示物件進行樣式調整。Display 類別應定義一個或兩個類別方法:from_estimatorfrom_predictions。這些方法允許從估算器和一些資料或從真實值和預測值建立 Display 物件。在這些類別方法使用計算的值建立顯示物件後,則呼叫顯示的 plot 方法。請注意,plot 方法定義了與 matplotlib 相關的屬性,例如線條 artist。這允許在呼叫 plot 方法後進行自訂。

例如,RocCurveDisplay 定義了以下方法和屬性

class RocCurveDisplay:
    def __init__(self, fpr, tpr, roc_auc, estimator_name):
        ...
        self.fpr = fpr
        self.tpr = tpr
        self.roc_auc = roc_auc
        self.estimator_name = estimator_name

    @classmethod
    def from_estimator(cls, estimator, X, y):
        # get the predictions
        y_pred = estimator.predict_proba(X)[:, 1]
        return cls.from_predictions(y, y_pred, estimator.__class__.__name__)

    @classmethod
    def from_predictions(cls, y, y_pred, estimator_name):
        # do ROC computation from y and y_pred
        fpr, tpr, roc_auc = ...
        viz = RocCurveDisplay(fpr, tpr, roc_auc, estimator_name)
        return viz.plot()

    def plot(self, ax=None, name=None, **kwargs):
        ...
        self.line_ = ...
        self.ax_ = ax
        self.figure_ = ax.figure_

請在使用可視化 API 的 ROC 曲線使用者指南中閱讀更多資訊。

使用多個軸繪圖#

一些繪圖工具(例如from_estimatorPartialDependenceDisplay)支援在多個軸上繪圖。支援兩種不同的情境

1. 如果傳入軸的清單,plot 將檢查軸的數量是否與其預期的軸數一致,然後在這些軸上繪圖。2. 如果傳入單個軸,則該軸定義要放置多個軸的空間。在這種情況下,我們建議使用 matplotlib 的 ~matplotlib.gridspec.GridSpecFromSubplotSpec 來分割空間

import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpecFromSubplotSpec

fig, ax = plt.subplots()
gs = GridSpecFromSubplotSpec(2, 2, subplot_spec=ax.get_subplotspec())

ax_top_left = fig.add_subplot(gs[0, 0])
ax_top_right = fig.add_subplot(gs[0, 1])
ax_bottom = fig.add_subplot(gs[1, :])

預設情況下,plot 中的 ax 關鍵字為 None。在這種情況下,將建立單個軸,並使用 gridspec api 來建立繪圖區域。

例如,請參閱from_estimator,它使用此 API 繪製多條線條和輪廓。定義邊界框的軸儲存在 bounding_ax_ 屬性中。建立的個別軸儲存在 axes_ ndarray 中,對應於網格上的軸位置。未使用的位置會設定為 None。此外,matplotlib Artists 儲存在 lines_contours_ 中,其中索引鍵是網格上的位置。當傳入軸的清單時,axes_lines_contours_ 是一個 1d ndarray,對應於傳入的軸清單。