注意
前往末尾下載完整的範例程式碼。 或透過 JupyterLite 或 Binder 在您的瀏覽器中執行此範例
GMM 初始化方法#
高斯混合模型中不同初始化方法的範例
請參閱 高斯混合模型 以取得有關估計器的更多資訊。
在此,我們生成一些具有四個容易識別的叢集的範例資料。 此範例的目的是展示初始化參數 *init_param* 的四種不同方法。
四個初始化是 *kmeans* (預設)、*random*、*random_from_data* 和 *k-means++*。
橙色菱形表示由 *init_param* 產生的 gmm 的初始化中心。 其餘資料表示為十字,顏色代表 GMM 完成後最終關聯的分類。
每個子圖的右上角的數字表示 GaussianMixture 收斂所花費的迭代次數,以及演算法初始化部分執行的相對時間。 較短的初始化時間往往需要更多的迭代才能收斂。
初始化時間是該方法所花費的時間與預設 *kmeans* 方法所花費的時間之比。 如您所見,與 *kmeans* 相比,所有三種替代方法都花費較少的時間初始化。
在此範例中,當使用 *random_from_data* 或 *random* 初始化時,模型需要更多迭代才能收斂。 在此,*k-means++* 在初始化時間短和 GaussianMixture 收斂的迭代次數少方面都做得很好。

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
from timeit import default_timer as timer
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets._samples_generator import make_blobs
from sklearn.mixture import GaussianMixture
from sklearn.utils.extmath import row_norms
print(__doc__)
# Generate some data
X, y_true = make_blobs(n_samples=4000, centers=4, cluster_std=0.60, random_state=0)
X = X[:, ::-1]
n_samples = 4000
n_components = 4
x_squared_norms = row_norms(X, squared=True)
def get_initial_means(X, init_params, r):
# Run a GaussianMixture with max_iter=0 to output the initialization means
gmm = GaussianMixture(
n_components=4, init_params=init_params, tol=1e-9, max_iter=0, random_state=r
).fit(X)
return gmm.means_
methods = ["kmeans", "random_from_data", "k-means++", "random"]
colors = ["navy", "turquoise", "cornflowerblue", "darkorange"]
times_init = {}
relative_times = {}
plt.figure(figsize=(4 * len(methods) // 2, 6))
plt.subplots_adjust(
bottom=0.1, top=0.9, hspace=0.15, wspace=0.05, left=0.05, right=0.95
)
for n, method in enumerate(methods):
r = np.random.RandomState(seed=1234)
plt.subplot(2, len(methods) // 2, n + 1)
start = timer()
ini = get_initial_means(X, method, r)
end = timer()
init_time = end - start
gmm = GaussianMixture(
n_components=4, means_init=ini, tol=1e-9, max_iter=2000, random_state=r
).fit(X)
times_init[method] = init_time
for i, color in enumerate(colors):
data = X[gmm.predict(X) == i]
plt.scatter(data[:, 0], data[:, 1], color=color, marker="x")
plt.scatter(
ini[:, 0], ini[:, 1], s=75, marker="D", c="orange", lw=1.5, edgecolors="black"
)
relative_times[method] = times_init[method] / times_init[methods[0]]
plt.xticks(())
plt.yticks(())
plt.title(method, loc="left", fontsize=12)
plt.title(
"Iter %i | Init Time %.2fx" % (gmm.n_iter_, relative_times[method]),
loc="right",
fontsize=10,
)
plt.suptitle("GMM iterations and relative time taken to initialize")
plt.show()
腳本的總執行時間: (0 分鐘 0.636 秒)
相關範例