9. 模型持久化#
持久化方法 |
優點 |
風險 / 缺點 |
---|---|---|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
在訓練 scikit-learn 模型後,需要有一種方法可以持久化模型以供日後使用,而無需重新訓練。根據您的使用案例,有幾種不同的方法可以持久化 scikit-learn 模型,在此我們幫助您決定哪一種最適合您。為了做出決定,您需要回答以下問題
您在持久化後是否需要 Python 物件,或者您只需要持久化才能提供模型並從中獲取預測?
如果您只需要提供模型,而不需要對 Python 物件本身進行進一步調查,那麼 ONNX 可能最適合您。請注意,並非所有模型都受 ONNX 支援。
如果 ONNX 不適合您的使用案例,則下一個問題是
您是否完全信任模型的來源,或者對於持久化模型的來源有任何安全性疑慮?
如果您有安全性疑慮,則應考慮使用 skops.io,它可以將 Python 物件返回給您,但與基於 pickle
的持久化解決方案不同,載入持久化模型不會自動允許執行任意程式碼。請注意,這需要手動檢查持久化檔案,skops.io
允許您執行此操作。
其他解決方案假設您完全信任要載入的檔案來源,因為它們都容易在載入持久化檔案時執行任意程式碼,因為它們都在底層使用 pickle 協議。
您是否關心載入模型的效能,以及在磁碟上的記憶體映射物件有利的情況下,在流程之間共享模型?
如果是,則可以考慮使用 joblib。如果這不是您主要關心的問題,則可以使用內建的 pickle
模組。
如果是,則可以使用 cloudpickle,它可以序列化某些 pickle
或 joblib
無法序列化的物件。
9.1. 工作流程概述#
在典型的工作流程中,第一步是使用 scikit-learn 和與 scikit-learn 相容的程式庫訓練模型。請注意,不同持久化方法對 scikit-learn 和第三方估算器的支援程度不同。
9.1.1. 訓練和持久化模型#
建立適當的模型取決於您的使用案例。例如,這裡我們在 iris 資料集上訓練 sklearn.ensemble.HistGradientBoostingClassifier
>>> from sklearn import ensemble
>>> from sklearn import datasets
>>> clf = ensemble.HistGradientBoostingClassifier()
>>> X, y = datasets.load_iris(return_X_y=True)
>>> clf.fit(X, y)
HistGradientBoostingClassifier()
訓練好模型後,您可以使用所需的方法持久化模型,然後可以在單獨的環境中載入模型,並從輸入資料中取得預測。這裡有兩條主要路徑,取決於您如何持久化並計畫提供模型
ONNX:您需要一個
ONNX
執行階段和一個安裝了適當相依性的環境,以載入模型並使用執行階段來取得預測。此環境可以很小,甚至不需要安裝 Python 即可載入模型並計算預測。另請注意,與 Python 相比,onnxruntime
通常需要少得多的 RAM 來計算小型模型的預測。skops.io
、pickle
、joblib
、cloudpickle:您需要一個安裝了適當相依性的 Python 環境,才能載入模型並從中取得預測。此環境應具有與訓練模型時相同的套件和相同的版本。請注意,這些方法都不支援載入使用不同版本的 scikit-learn 訓練的模型,以及其他相依性(例如numpy
和scipy
)的不同版本。另一個需要考慮的問題是在不同的硬體上執行持久化模型,在大多數情況下,您應該能夠在不同的硬體上載入持久化模型。
9.2. ONNX#
ONNX
或 開放式神經網路交換格式最適合需要持久化模型,然後使用持久化成品來取得預測,而無需載入 Python 物件本身的使用案例。它也適用於需要精簡且最小的服務環境的情況,因為 ONNX
執行階段不需要 python
。
ONNX
是一種模型的二進制序列化格式。它的開發旨在提高資料模型互通表示的可用性。其目標是促進不同機器學習框架之間資料模型的轉換,並提高它們在不同計算架構上的可移植性。更多詳細資訊請參閱 ONNX 教學。為了將 scikit-learn 模型轉換為 ONNX
,開發了 sklearn-onnx。然而,並非所有 scikit-learn 模型都受支援,而且它僅限於核心 scikit-learn,不支援大多數第三方估計器。雖然可以為第三方或自訂估計器編寫自訂轉換器,但相關文件稀少,且可能具有挑戰性。
使用 ONNX#
要將模型轉換為 ONNX
格式,您還需要向轉換器提供一些關於輸入的資訊,您可以在 這裡 閱讀更多相關資訊。
from skl2onnx import to_onnx
onx = to_onnx(clf, X[:1].astype(numpy.float32), target_opset=12)
with open("filename.onnx", "wb") as f:
f.write(onx.SerializeToString())
您可以在 Python 中載入模型,並使用 ONNX
運行時取得預測結果。
from onnxruntime import InferenceSession
with open("filename.onnx", "rb") as f:
onx = f.read()
sess = InferenceSession(onx, providers=["CPUExecutionProvider"])
pred_ort = sess.run(None, {"X": X_test.astype(numpy.float32)})[0]
9.3. skops.io
#
skops.io
避免使用 pickle
,並且只載入具有預設或由使用者信任的類型和函式參照的檔案。因此,它提供比 pickle
、 joblib
和 cloudpickle 更安全的格式。
使用 skops#
API 與 pickle
非常相似,您可以按照 文件 中的說明,使用 skops.io.dump
和 skops.io.dumps
持久化您的模型。
import skops.io as sio
obj = sio.dump(clf, "filename.skops")
您可以使用 skops.io.load
和 skops.io.loads
將它們載入回來。但是,您需要指定您信任的類型。您可以使用 skops.io.get_untrusted_types
在已傾印的物件/檔案中取得現有的未知類型,並在檢查其內容後,將其傳遞給載入函式。
unknown_types = sio.get_untrusted_types(file="filename.skops")
# investigate the contents of unknown_types, and only load if you trust
# everything you see.
clf = sio.load("filename.skops", trusted=unknown_types)
請在 skops issue tracker 上回報與此格式相關的問題和功能請求。
9.4. pickle
、joblib
和 cloudpickle
#
這三個模組/套件在底層使用 pickle
協定,但帶有一些細微的變化。
pickle
是 Python 標準函式庫中的一個模組。它可以序列化和反序列化任何 Python 物件,包括自訂 Python 類別和物件。當處理大型機器學習模型或大型 numpy 陣列時,
joblib
比pickle
更有效率。cloudpickle 可以序列化某些
pickle
或joblib
無法序列化的物件,例如使用者定義的函式和 lambda 函式。例如,當使用FunctionTransformer
並使用自訂函式轉換資料時,可能會發生這種情況。
使用 pickle
、joblib
或 cloudpickle
#
根據您的使用案例,您可以選擇這三種方法之一來持久化和載入您的 scikit-learn 模型,它們都遵循相同的 API。
# Here you can replace pickle with joblib or cloudpickle
from pickle import dump
with open("filename.pkl", "wb") as f:
dump(clf, f, protocol=5)
建議使用 protocol=5
來減少記憶體使用量,並加快儲存和載入模型中作為已擬合屬性儲存的任何大型 NumPy 陣列的速度。或者,您可以傳遞 protocol=pickle.HIGHEST_PROTOCOL
,這在 Python 3.8 及更高版本中(撰寫本文時)等效於 protocol=5
。
稍後,當需要時,您可以從持久化檔案載入相同的物件。
# Here you can replace pickle with joblib or cloudpickle
from pickle import load
with open("filename.pkl", "rb") as f:
clf = load(f)
9.5. 安全性與可維護性限制#
pickle
(以及 joblib
和 clouldpickle
擴充)因其設計而具有許多已記錄的安全漏洞,只有當工件(即 pickle 檔案)來自受信任且經過驗證的來源時,才應使用它。您永遠不應從不受信任的來源載入 pickle 檔案,就像您永遠不應執行來自不受信任來源的程式碼一樣。
另請注意,可以使用 ONNX
格式表示任意計算,因此建議在沙箱環境中使用 ONNX
來提供模型服務,以防止計算和記憶體漏洞。
另請注意,沒有支援的方法可以載入使用不同版本的 scikit-learn 訓練的模型。雖然使用 skops.io
、joblib
、pickle
或 cloudpickle,使用某個版本的 scikit-learn 儲存的模型可能會在其他版本中載入,但是,這完全不受支援,也不建議這樣做。還應記住,對此類資料執行的操作可能會產生不同且意料之外的結果,甚至可能導致 Python 處理程序崩潰。
為了使用未來版本的 scikit-learn 重建類似的模型,應將額外的中繼資料與 pickling 模型一起儲存。
訓練資料,例如,不可變快照的參照。
用於產生模型的 Python 原始碼。
scikit-learn 及其相依性的版本。
在訓練資料上取得的交叉驗證分數。
這應能確保檢查交叉驗證分數是否與以前的範圍相同。
除了少數例外情況,假設使用相同版本的相依性和 Python,持久化模型應可在作業系統和硬體架構之間移植。如果您遇到無法移植的估計器,請在 GitHub 上開啟一個 issue。持久化模型通常會使用 Docker 等容器部署到生產環境中,以凍結環境和相依性。
如果您想了解更多關於這些問題的資訊,請參閱以下演講
9.5.1. 在生產環境中複製訓練環境#
如果使用的相依套件版本在訓練和生產環境中可能不同,可能會導致在使用已訓練模型時出現意外行為和錯誤。為了防止這種情況,建議在訓練和生產環境中使用相同的相依套件及其版本。這些傳遞相依性可以使用套件管理工具(如 pip
、mamba
、conda
、poetry
、conda-lock
、pixi
等)來固定。
在更新的軟體環境中載入使用較舊版本的 scikit-learn 函式庫及其相依套件訓練的模型並不總是可行。相反地,您可能需要使用所有函式庫的新版本重新訓練模型。因此,在訓練模型時,記錄訓練配方(例如 Python 腳本)和訓練集資訊,以及有關所有相依套件的中繼資料非常重要,以便能夠自動重建相同的訓練環境以適應更新的軟體。
InconsistentVersionWarning#
當使用與估算器 pickled 時的版本不一致的 scikit-learn 版本載入估算器時,會引發 InconsistentVersionWarning
。可以捕獲此警告以取得估算器 pickled 時的原始版本。
from sklearn.exceptions import InconsistentVersionWarning
warnings.simplefilter("error", InconsistentVersionWarning)
try:
with open("model_from_prevision_version.pickle", "rb") as f:
est = pickle.load(f)
except InconsistentVersionWarning as w:
print(w.original_sklearn_version)
9.5.2. 提供模型成品#
訓練 scikit-learn 模型後的最後一步是提供模型。成功載入已訓練的模型後,可以根據規範將其部署為 Web 服務,或使用其他模型部署策略,以管理不同的預測請求。
9.6. 重點總結#
根據不同的模型持久化方法,每種方法的重點可以總結如下:
ONNX
:它為持久化任何機器學習或深度學習模型(除了 scikit-learn)提供了一種統一的格式,並且對於模型推論(預測)非常有用。但是,它可能會導致與不同框架的相容性問題。skops.io
:可以使用skops.io
輕鬆共享已訓練的 scikit-learn 模型並將其投入生產。與基於pickle
的其他方法相比,它更安全,因為它不會載入任意程式碼,除非使用者明確要求。此類程式碼需要在目標 Python 環境中打包且可導入。joblib
:當使用mmap_mode="r"
在多個 Python 程序中使用相同的持久化模型時,有效的記憶體映射技術使其更快。它還提供了壓縮和解壓縮持久化物件的簡單捷徑,而無需額外的程式碼。但是,與任何其他基於 pickle 的持久化機制一樣,當從不受信任的來源載入模型時,它可能會觸發惡意程式碼的執行。pickle
:它是 Python 原生的,大多數 Python 物件都可以使用pickle
進行序列化和反序列化,包括自訂 Python 類別和函式,只要它們在可以導入到目標環境的套件中定義即可。雖然pickle
可以用來輕鬆儲存和載入 scikit-learn 模型,但從不受信任的來源載入模型時,它可能會觸發惡意程式碼的執行。pickle
如果模型是以protocol=5
持久化的,記憶體使用效率也很高,但不支援記憶體映射。cloudpickle:它具有與
pickle
和joblib
(不使用記憶體映射) 相近的載入效率,但提供了額外的彈性來序列化自訂 Python 程式碼,例如 lambda 表達式和互動式定義的函式和類別。它可能是持久化具有自訂 Python 組件的管道的最後手段,例如包裝在訓練腳本本身或更普遍地在任何可導入的 Python 套件之外定義的函式的sklearn.preprocessing.FunctionTransformer
。請注意,cloudpickle 不提供向前相容性保證,您可能需要相同版本的 cloudpickle 以及用於定義模型的所有函式庫的相同版本,才能載入持久化的模型。與其他基於 pickle 的持久化機制一樣,從不受信任的來源載入模型時,它可能會觸發惡意程式碼的執行。