# coding=UTF-8
import random
import matplotlib.pyplot as plt
import numpy as np
from numpy import mean, inf
from sklearn import datasets
# K-means clustering implementation (Kmeans算法实现)
class Kmeans(object):
    """K-means clustering (Lloyd's algorithm) with random bounding-box initialisation.

    Parameters:
        k: number of clusters.
        n_iter: maximum number of assign/update iterations per (re)start.
        max_restarts: maximum number of times all centers are re-initialised
            because some cluster received no samples. This bounds the run time
            when k non-empty clusters cannot be found (e.g. duplicate points).

    Attributes:
        c_point: array of shape (k, n_features) with the cluster centers.
        assement: array of shape (n_samples,) with the predicted cluster index
            of every sample, stored as floats (0.0 .. k-1).
    """

    def __init__(self, k=3, n_iter=50, max_restarts=100):
        self.k = k
        self.n_iter = n_iter
        self.max_restarts = max_restarts
        # Set by fit(): cluster centers, shape (k, n_features).
        self.c_point = None
        # Set by fit(): cluster index of each sample (float dtype).
        self.assement = None

    def fit(self, x):
        """Cluster the samples in ``x`` and return ``self``.

        Each iteration assigns every sample to its nearest center (Euclidean
        distance; ties go to the lowest center index). It then moves every
        non-empty cluster's center to the mean of its members. Iteration stops
        early once the assignment no longer changes. If a cluster ends up
        empty, all centers are re-initialised and the iteration counter
        restarts, at most ``max_restarts`` times.

        Args:
            x: array-like of shape (n_samples, n_features).

        Returns:
            self, with ``c_point`` and ``assement`` populated.

        Raises:
            ValueError: if ``k`` < 1 or there are fewer samples than clusters
                (k non-empty clusters would be impossible).
        """
        x = np.array(x, dtype=float)
        if self.k < 1:
            raise ValueError('k must be at least 1, got {}'.format(self.k))
        if x.shape[0] < self.k:
            raise ValueError(
                'need at least k={} samples, got {}'.format(self.k, x.shape[0]))
        self.assement = np.zeros(x.shape[0])
        self.c_point = self.SelectCenterPoint(x)
        prev_labels = None
        restarts = 0
        count = 0
        while count < self.n_iter:
            # Distances are recomputed against the current centers on every
            # iteration, so the assignment is always the true nearest center.
            labels = self._nearest_center(x)
            self.assement = labels.astype(float)
            # Unchanged assignment: the centers are already the means of these
            # clusters, so the algorithm has converged.
            if prev_labels is not None and np.array_equal(labels, prev_labels):
                break
            sizes = np.bincount(labels, minlength=self.k)
            # A center that attracted no samples has drifted away from the
            # data: re-initialise all centers and start over (bounded).
            if np.any(sizes == 0) and restarts < self.max_restarts:
                restarts += 1
                self.c_point = self.SelectCenterPoint(x)
                prev_labels = None
                count = 0
                continue
            # Move each non-empty cluster's center to its members' mean. Empty
            # clusters (possible only once restarts are exhausted) keep their
            # previous position instead of becoming NaN.
            for j in np.flatnonzero(sizes):
                self.c_point[j] = x[labels == j].mean(axis=0)
            prev_labels = labels
            count += 1
        return self

    def _nearest_center(self, x):
        """Return, for every row of ``x``, the index of the nearest center."""
        # Broadcast to (n_samples, k, n_features) differences.
        diffs = x[:, np.newaxis, :] - self.c_point[np.newaxis, :, :]
        dists = np.sqrt(np.sum(diffs ** 2, axis=2))
        # argmin returns the first minimum, i.e. ties go to the lowest index.
        return np.argmin(dists, axis=1)

    def SelectCenterPoint(self, x):
        """Draw ``k`` random centers uniformly inside the bounding box of ``x``.

        Coordinate j of each center is drawn with ``random.uniform`` between
        the minimum and maximum of column j of ``x``.

        Returns:
            array of shape (k, n_features).
        """
        x = np.asarray(x)
        dimension = x.shape[1]
        # Column bounds do not depend on the loop variables; compute once.
        lows = np.min(x, axis=0)
        highs = np.max(x, axis=0)
        points = np.zeros((self.k, dimension))
        for i in range(self.k):
            for j in range(dimension):
                points[i, j] = random.uniform(lows[j], highs[j])
        return points

    def euc_dis(self, a, b):
        """Return the Euclidean distance between vectors ``a`` and ``b``."""
        a = np.asarray(a, dtype=float)
        b = np.asarray(b, dtype=float)
        return np.sqrt(np.sum((a - b) ** 2))

    def setequal(self, a, b, k):
        """Return True if ``a`` and ``b`` both hold ``k`` groups and group i of
        ``a`` has exactly the same elements, in order, as group i of ``b``.

        Groups may be nested lists or numpy arrays. ``np.array_equal`` is used
        so that comparing arrays never raises an ambiguous-truth-value error.
        """
        if len(a) != k or len(b) != k:
            return False
        return all(np.array_equal(a[i], b[i]) for i in range(k))

    def drawresult(self, x, y):
        """Scatter-plot the 2-D samples ``x`` coloured by label ``y``, plus the
        fitted centers in black, then show the figure.

        Only the first two columns of ``x`` and ``c_point`` are plotted. The
        distinct labels are taken in sorted order and coloured blue, green,
        red, then matplotlib's default cycle ('C3', 'C4', ...). Any number of
        labels is supported, and labels with no samples are simply absent.
        """
        x = np.asarray(x)
        y = np.asarray(y)
        colors = ['blue', 'green', 'red']
        plt.scatter(self.c_point[:, 0], self.c_point[:, 1], c='black', alpha=0.9)
        for idx, label in enumerate(np.unique(y)):
            members = x[y == label]
            color = colors[idx] if idx < len(colors) else 'C{}'.format(idx)
            plt.scatter(members[:, 0], members[:, 1], c=color, alpha=0.5)
        plt.show()
if __name__ == '__main__':
    # Demo: cluster the iris dataset bundled with scikit-learn.
    # Only feature columns 1 and 2 are kept so that the clusters can be
    # drawn on a 2-D scatter plot.
    features = datasets.load_iris().data[:, [1, 2]]
    # fit() returns the model itself, so construction and fitting chain.
    model = Kmeans().fit(features)
    model.drawresult(features, model.assement)
# Source article: "Python实现Kmeans算法" (Implementing the K-means algorithm in Python)
# 最新推荐文章于 2025-06-02 16:14:04 发布