# -*- coding: utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
def kmeans(data, center_ids, max_err=0.0001, max_round=30):
    """Run K-means clustering, yielding the clustering state after every round.

    Because this is a generator, each training round can be inspected (e.g.
    animated) as it is produced.

    Parameters
    ----------
    data : array_like, shape (m, d)
        Samples to cluster, one sample per row. Converted to float so that
        integer input works.
    center_ids : sequence of int
        Row indices of ``data`` used as the initial centres; its length is
        the number of clusters.
    max_err : float
        Training stops once the summed Euclidean shift of all centres in a
        round is not greater than this value.
    max_round : int
        Maximum number of rounds to run.

    Yields
    ------
    clusters : list of list of int
        For each centre, the row indices of the samples assigned to it.
    centers : list of numpy.ndarray
        Each cluster's centre recomputed from this round's assignment.
    rounds : int
        1-based number of the round just completed.

    Raises
    ------
    ValueError
        If ``center_ids`` is empty (raised on the first ``next()``).
    """
    data = np.asarray(data, dtype=float)
    if len(center_ids) == 0:
        raise ValueError('center_ids must contain at least one sample index')
    # Fancy indexing copies the rows, so updating centres never touches `data`.
    centers = data[list(center_ids), :]
    n = len(centers)
    error, rounds = 1.0, 0
    while error > max_err and rounds < max_round:
        rounds += 1
        # Euclidean distance of every sample to every centre, shape (m, n).
        # argmin returns the first minimum, matching the original stable-sort
        # tie-breaking.
        diff = data[:, np.newaxis, :] - centers[np.newaxis, :, :]
        dist = np.sqrt(np.sum(diff * diff, axis=2))
        labels = np.argmin(dist, axis=1)
        clusters = [np.flatnonzero(labels == i).tolist() for i in range(n)]
        new_centers = centers.copy()
        for i, members in enumerate(clusters):
            # An empty cluster has no mean: keep its previous centre rather than
            # dividing by zero (NaN centres would also end training early).
            if members:
                new_centers[i] = data[members].mean(axis=0)
        error = float(np.sum(np.sqrt(np.sum((new_centers - centers) ** 2, axis=1))))
        centers = new_centers
        yield clusters, list(new_centers), rounds
# Samples to cluster: column 0 is plotted as density, column 1 as sugar.
data = np.array([
    [0.697, 0.460], [0.774, 0.376], [0.634, 0.264], [0.608, 0.318], [0.556, 0.215],
    [0.403, 0.237], [0.481, 0.149], [0.437, 0.211], [0.666, 0.091], [0.243, 0.267],
    [0.245, 0.057], [0.343, 0.099], [0.639, 0.161], [0.657, 0.198], [0.360, 0.370],
    [0.593, 0.042], [0.719, 0.103], [0.359, 0.188], [0.339, 0.241], [0.282, 0.257],
    [0.748, 0.232], [0.714, 0.346], [0.483, 0.312], [0.478, 0.437], [0.525, 0.369],
    [0.751, 0.489], [0.532, 0.472], [0.473, 0.376], [0.725, 0.445], [0.446, 0.459]])
# Row ids of the samples chosen as initial centres; their count is also the
# number of clusters.
init_centers = [12, 22]
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.set_xlim(0, 1)
ax.set_ylim(0, 0.6)
ax.set_ylabel('sugar')
ax.set_xlabel('density')
# Per-cluster colours, reused cyclically so more than 5 clusters still work.
dye = ['red', 'orange', 'green', 'blue', 'pink']
imgs = []
# Save each round's clustering as one animation frame (a list of artists).
for cluster, center, rounds in kmeans(data, init_centers):
    # ArtistAnimation only toggles the artists listed per frame, so a shared
    # ax.set_title would show the last round's text in every frame. Use a
    # per-frame text artist, placed inside the axes so blitting redraws it.
    pics = [ax.text(0.02, 0.97, 'clusters in %s rounds' % rounds,
                    transform=ax.transAxes, va='top')]
    for i, li in enumerate(cluster):
        pics.append(ax.scatter(data[li, 0], data[li, 1], c=dye[i % len(dye)]))
        pics.append(ax.scatter(center[i][0], center[i][1], s=45, c='gray', marker='s'))
    imgs.append(pics)
# First frame: all samples in black, before any clustering.
imgs.insert(0, [ax.scatter(data[:, 0], data[:, 1], c='k')])
A = animation.ArtistAnimation(fig, imgs, interval=1000, blit=True, repeat_delay=500)
# Save before showing: once the window is closed the figure's canvas may be
# unusable. The Pillow writer needs no external ImageMagick install.
A.save('3point.gif', fps=2, writer='pillow')  # output path; GIF frames per second
plt.show()
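# K-means算法的2类聚类: (K-means clustering with 2 clusters)
# K-means算法的3类聚类: (K-means clustering with 3 clusters)
# K-means算法的4类聚类: (K-means clustering with 4 clusters)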