普通最小平方範例#

此範例示範如何使用 scikit-learn 中名為 LinearRegression 的普通最小平方 (OLS) 模型。

為此,我們使用糖尿病數據集中的單一特徵,並嘗試使用此線性模型來預測糖尿病進展。因此,我們載入糖尿病數據集並將其拆分為訓練集和測試集。

然後,我們在訓練集上擬合模型,並評估其在測試集上的效能,最後視覺化測試集上的結果。

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

數據載入和準備#

載入糖尿病數據集。為了簡單起見,我們只在資料中保留單一特徵。然後,我們將數據和目標拆分為訓練集和測試集。

from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split

X, y = load_diabetes(return_X_y=True)
X = X[:, [2]]  # Use only one feature
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=20, shuffle=False)

線性回歸模型#

我們建立線性回歸模型並將其擬合到訓練數據上。請注意,預設情況下,模型會新增一個截距。我們可以透過設定 fit_intercept 參數來控制此行為。

from sklearn.linear_model import LinearRegression

regressor = LinearRegression().fit(X_train, y_train)

模型評估#

我們使用均方誤差和判定係數來評估模型在測試集上的效能。

from sklearn.metrics import mean_squared_error, r2_score

y_pred = regressor.predict(X_test)

print(f"Mean squared error: {mean_squared_error(y_test, y_pred):.2f}")
print(f"Coefficient of determination: {r2_score(y_test, y_pred):.2f}")
Mean squared error: 2548.07
Coefficient of determination: 0.47

繪製結果#

最後,我們視覺化訓練和測試數據的結果。

import matplotlib.pyplot as plt

fig, ax = plt.subplots(ncols=2, figsize=(10, 5), sharex=True, sharey=True)

ax[0].scatter(X_train, y_train, label="Train data points")
ax[0].plot(
    X_train,
    regressor.predict(X_train),
    linewidth=3,
    color="tab:orange",
    label="Model predictions",
)
ax[0].set(xlabel="Feature", ylabel="Target", title="Train set")
ax[0].legend()

ax[1].scatter(X_test, y_test, label="Test data points")
ax[1].plot(X_test, y_pred, linewidth=3, color="tab:orange", label="Model predictions")
ax[1].set(xlabel="Feature", ylabel="Target", title="Test set")
ax[1].legend()

fig.suptitle("Linear Regression")

plt.show()
Linear Regression, Train set, Test set

結論#

經過訓練的模型對應於最小化訓練數據上預測值和真實目標值之間均方誤差的估算器。因此,我們獲得了給定數據的目標條件均值的估算器。

請注意,在較高的維度中,僅最小化平方誤差可能會導致過擬合。因此,通常會使用正規化技術來防止此問題,例如在 RidgeLasso 中實作的技術。

腳本的總執行時間: (0 分鐘 0.205 秒)

相關範例

繪製個別和投票迴歸預測

繪製個別和投票迴歸預測

繪製交叉驗證預測

繪製交叉驗證預測

非負最小平方

非負最小平方

梯度提升迴歸

梯度提升迴歸

由 Sphinx-Gallery 產生的圖庫