注意
前往結尾下載完整的範例程式碼。或透過 JupyterLite 或 Binder 在您的瀏覽器中執行此範例
辨識手寫數字#
此範例示範如何使用 scikit-learn 來辨識手寫數字影像,從 0 到 9。
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
# Standard scientific Python imports
import matplotlib.pyplot as plt
# Import datasets, classifiers and performance metrics
from sklearn import datasets, metrics, svm
from sklearn.model_selection import train_test_split
數字資料集#
數字資料集由 8x8 像素的數字影像組成。資料集的 images
屬性儲存每個影像的 8x8 灰階值陣列。我們將使用這些陣列來視覺化前 4 個影像。資料集的 target
屬性儲存每個影像代表的數字,並包含在以下 4 個圖表的標題中。
注意:如果我們從影像檔案(例如「png」檔案)工作,我們將使用 matplotlib.pyplot.imread
來載入它們。
digits = datasets.load_digits()
_, axes = plt.subplots(nrows=1, ncols=4, figsize=(10, 3))
for ax, image, label in zip(axes, digits.images, digits.target):
ax.set_axis_off()
ax.imshow(image, cmap=plt.cm.gray_r, interpolation="nearest")
ax.set_title("Training: %i" % label)
data:image/s3,"s3://crabby-images/06aa7/06aa7a7282a4d937f368e3c0e9fcaf8caf7956cd" alt="Training: 0, Training: 1, Training: 2, Training: 3"
分類#
為了將分類器應用於此資料,我們需要將影像展平,將形狀為 (8, 8)
的每個 2 維灰階值陣列轉換為形狀為 (64,)
的陣列。隨後,整個資料集的形狀將為 (n_samples, n_features)
,其中 n_samples
是影像數量,而 n_features
是每個影像中的像素總數。
然後,我們可以將資料分割成訓練和測試子集,並在訓練樣本上擬合支援向量分類器。擬合的分類器隨後可用於預測測試子集中樣本的數字值。
# flatten the images
n_samples = len(digits.images)
data = digits.images.reshape((n_samples, -1))
# Create a classifier: a support vector classifier
clf = svm.SVC(gamma=0.001)
# Split data into 50% train and 50% test subsets
X_train, X_test, y_train, y_test = train_test_split(
data, digits.target, test_size=0.5, shuffle=False
)
# Learn the digits on the train subset
clf.fit(X_train, y_train)
# Predict the value of the digit on the test subset
predicted = clf.predict(X_test)
以下我們視覺化前 4 個測試樣本,並在標題中顯示它們的預測數字值。
_, axes = plt.subplots(nrows=1, ncols=4, figsize=(10, 3))
for ax, image, prediction in zip(axes, X_test, predicted):
ax.set_axis_off()
image = image.reshape(8, 8)
ax.imshow(image, cmap=plt.cm.gray_r, interpolation="nearest")
ax.set_title(f"Prediction: {prediction}")
data:image/s3,"s3://crabby-images/9a129/9a129231157cffcb9f0341832590451fe1892d33" alt="Prediction: 8, Prediction: 8, Prediction: 4, Prediction: 9"
classification_report
建構一個文字報告,顯示主要的分類指標。
print(
f"Classification report for classifier {clf}:\n"
f"{metrics.classification_report(y_test, predicted)}\n"
)
Classification report for classifier SVC(gamma=0.001):
precision recall f1-score support
0 1.00 0.99 0.99 88
1 0.99 0.97 0.98 91
2 0.99 0.99 0.99 86
3 0.98 0.87 0.92 91
4 0.99 0.96 0.97 92
5 0.95 0.97 0.96 91
6 0.99 0.99 0.99 91
7 0.96 0.99 0.97 89
8 0.94 1.00 0.97 88
9 0.93 0.98 0.95 92
accuracy 0.97 899
macro avg 0.97 0.97 0.97 899
weighted avg 0.97 0.97 0.97 899
我們也可以繪製真實數字值和預測數字值的混淆矩陣。
disp = metrics.ConfusionMatrixDisplay.from_predictions(y_test, predicted)
disp.figure_.suptitle("Confusion Matrix")
print(f"Confusion matrix:\n{disp.confusion_matrix}")
plt.show()
data:image/s3,"s3://crabby-images/f95d0/f95d0c7f07e161d0e43eb7a6439d3eb356bf454a" alt="Confusion Matrix"
Confusion matrix:
[[87 0 0 0 1 0 0 0 0 0]
[ 0 88 1 0 0 0 0 0 1 1]
[ 0 0 85 1 0 0 0 0 0 0]
[ 0 0 0 79 0 3 0 4 5 0]
[ 0 0 0 0 88 0 0 0 0 4]
[ 0 0 0 0 0 88 1 0 0 2]
[ 0 1 0 0 0 0 90 0 0 0]
[ 0 0 0 0 0 1 0 88 0 0]
[ 0 0 0 0 0 0 0 0 88 0]
[ 0 0 0 1 0 1 0 0 0 90]]
如果評估分類器的結果以混淆矩陣的形式儲存,而不是以 y_true
和 y_pred
的形式儲存,仍然可以如下建構classification_report
# The ground truth and predicted lists
y_true = []
y_pred = []
cm = disp.confusion_matrix
# For each cell in the confusion matrix, add the corresponding ground truths
# and predictions to the lists
for gt in range(len(cm)):
for pred in range(len(cm)):
y_true += [gt] * cm[gt][pred]
y_pred += [pred] * cm[gt][pred]
print(
"Classification report rebuilt from confusion matrix:\n"
f"{metrics.classification_report(y_true, y_pred)}\n"
)
Classification report rebuilt from confusion matrix:
precision recall f1-score support
0 1.00 0.99 0.99 88
1 0.99 0.97 0.98 91
2 0.99 0.99 0.99 86
3 0.98 0.87 0.92 91
4 0.99 0.96 0.97 92
5 0.95 0.97 0.96 91
6 0.99 0.99 0.99 91
7 0.96 0.99 0.97 89
8 0.94 1.00 0.97 88
9 0.93 0.98 0.95 92
accuracy 0.97 899
macro avg 0.97 0.97 0.97 899
weighted avg 0.97 0.97 0.97 899
腳本的總執行時間: (0 分鐘 0.435 秒)
相關範例