该程序可以对多特征数据进行聚类

#!/usr/bin/python
# -*- coding: utf-8 -*-
from numpy import *
import random
import matplotlib.pyplot as plt

plt.figure(0, figsize=(16, 6))
fig = plt.figure(0)


class Kmeans:

    def __init__(self, x, k, maxIterations):
        self.m, self.n = x.shape
        # 第一个参数表示其所属的类簇
        self.x = x
        self.k = k
        # 实际迭代次数
        self.count = 0
        if(self.k > self.m):
            raise NameError("设置的类簇个数大于数据样本数")
        self.maxIterations = maxIterations
        self.CenterPoint = zeros((k, self.n - 1))
        self.setKCenterPoint()

    def train(self):
        # 是否有点的类簇发生变化
        ischange = 1
        while(self.count < self.maxIterations and ischange > 0):
            ischange = 0
            for i in range(self.m):
                ischange += self.computeClass(i)
            if(ischange > 0):
                self.changeCenterPoint()
            self.count += 1

    def computeClass(self, i):
        k = -1
        temp = -1
        for j in range(self.k):
            flag = sqrt(
                sum(pow((self.x[i, 1:] - self.CenterPoint[j]).A, 2)))
            if(flag < temp or temp == -1):
                temp = flag
                k = j
        if(k == self.x[i, 0]):
            return 0
        else:
            self.x[i, 0] = k
            return 1

    # 重新计算类簇中心点
    def changeCenterPoint(self):
        self.CenterPoint = zeros((self.k, self.n - 1))
        for i in range(self.k):
            temp = nonzero(self.x[:, 0] == i)[0]
            for j in temp:
                self.CenterPoint[i] += self.x[j, 1:].A[0]
            self.CenterPoint[i] /= len(temp)

    # 计算所有数据点到其最近的中心点的平均距离cost
    # 一般来说,同样的迭代次数和算法跑的次数,这个值越小代表聚类的效果越好
    # 但是在实际情况下,我们还要考虑到聚类结果的可解释性,不能一味的选择使computeCost结果值最小的那个K
    def computeCost(self):
        cost = 0
        for i in range(self.m):
            cost += sqrt(sum(pow((self.x[i, 1:] -
                                  self.CenterPoint[int(self.x[i, 0])]).A, 2)))
        return cost / self.m

    # 选取K个类簇中心点的初始值
    # 1、随机选择一个点作为第一个类簇中心
    # 2、    1)选择距离已有类簇中心最近的点
    #        2)选择距离该点最远的点
    # 循环第2步直到选出k个类簇中心点
    def setKCenterPoint(self):
        # 随机选取一个作为第一个中心点
        i = random.randint(0, self.m - 1)
        self.CenterPoint[0] = self.x[i, 1:]
        for j in range(1, self.k):
            nearestPoint = self.nearestPoint(j)
            farthestPoint = self.farthestPoint(nearestPoint)
            self.CenterPoint[j] = farthestPoint

    # 选择最近的点
    def nearestPoint(self, i):
        # 获取第二个类簇中心点时,最近的点就是第一个中心点
        if(i == 1):
            return self.CenterPoint[0]
        else:
            nearestPoint = []
            temp = -1
            for j in range(self.m):
                flag = 0
                # i=2 表示获取第3个类簇中心点,获取距离前2个类簇中心点最近的点
                for m in range(i):
                    flag += sqrt(sum(pow((self.x[j, 1:] -
                                          self.CenterPoint[m]).A, 2)))
                if(flag < temp or temp == -1):
                    temp = flag
                    nearestPoint = self.x[j, 1:]
            return nearestPoint

    # 选择最远的点
    def farthestPoint(self, nearestPoint):
        farthestPoint = []
        temp = -1
        for i in range(self.m):
            flag = sqrt(sum(pow((self.x[i, 1:] - nearestPoint).A, 2)))
            if(flag > temp and self.x[i, 1:] not in self.CenterPoint):
                temp = flag
                farthestPoint = self.x[i, 1:]
        return farthestPoint

    def draw(self):
        ax = fig.add_subplot(1, 2, 2)
        mark1 = ['or', 'ob', 'og', 'ok', '^r', '+r', 'sr', 'dr', '<r', 'pr']
        mark2 = ['Dr', 'Db', 'Dg', 'Dk', '^b', '+b', 'sb', 'db', '<b', 'pb']
        # 根据聚类结果画出所有的点
        for i in range(self.m):
            ax.plot(self.x[i, 1], self.x[i, 2], mark1[int(self.x[i, 0])])
        # 所有类簇的中心点
        for j in range(self.k):
            ax.plot(self.CenterPoint[j, 0], self.CenterPoint[
                     j, 1], mark2[j], markersize=12)
        plt.xlabel("x1")
        plt.ylabel("x2")
        plt.show()

# ------------------------------开始-------------------------
x = []
files = open("/home/hadoop/Python/K-means/K-means.txt", "r")
for val in files.readlines():
    line = val.strip().split()
    data = [-1]
    for item in line:
        data.append(float(item))
    x.append(data)
x = mat(x)
# 通过computeCost计算cost确定K值
k = [1, 2, 3, 4, 5, 6]
cost = []
for i in k:
    model = Kmeans(x, i, 20)
    model.train()
    cost.append(model.computeCost())
ax = fig.add_subplot(1, 2, 1)
ax.plot(k, cost, "g")
plt.xlabel("k")
plt.ylabel("cost")

model = Kmeans(x, 4, 20)
model.train()
print(model.CenterPoint)
model.draw()

这里写图片描述

根据左图我们知道K=4
右图为我们聚类的结果,迭代了5次

1.658985    4.285136
-3.453687   3.424321
4.838138    -1.151539
-5.379713   -3.362104
0.972564    2.924086
-3.567919   1.531611
0.450614    -3.302219
-3.487105   -1.724432
2.668759    1.594842
-3.156485   3.191137
3.165506    -3.999838
-2.786837   -3.099354
4.208187    2.984927
-2.123337   2.943366
0.704199    -0.479481
-0.392370   -3.963704
2.831667    1.574018
-0.790153   3.343144
2.943496    -3.357075
-3.195883   -2.283926
2.336445    2.875106
-1.786345   2.554248
2.190101    -1.906020
-3.403367   -2.778288
1.778124    3.880832
-1.688346   2.230267
2.592976    -2.054368
-4.007257   -3.207066
2.257734    3.387564
-2.679011   0.785119
0.939512    -4.023563
-3.674424   -2.261084
2.046259    2.735279
-3.189470   1.780269
4.372646    -0.822248
-2.579316   -3.497576
1.889034    5.190400
-0.798747   2.185588
2.836520    -2.658556
-3.837877   -3.253815
2.096701    3.886007
-2.709034   2.923887
3.367037    -3.184789
-2.121479   -4.232586
2.329546    3.179764
-3.284816   3.273099
3.091414    -3.815232
-3.762093   -2.432191
3.542056    2.778832
-1.736822   4.241041
2.127073    -2.983680
-4.323818   -3.938116
3.792121    5.135768
-4.786473   3.358547
2.624081    -3.260715
-4.009299   -2.978115
2.493525    1.963710
-2.513661   2.642162
1.864375    -3.176309
-3.171184   -3.572452
2.894220    2.489128
-2.562539   2.884438
3.491078    -3.947487
-2.565729   -2.012114
3.332948    3.983102
-1.616805   3.573188
2.280615    -2.559444
-2.651229   -3.103198
2.321395    3.154987
-1.685703   2.939697
3.031012    -3.620252
-4.599622   -2.185829
4.196223    1.126677
-2.133863   3.093686
4.668892    -2.562705
-2.793241   -2.149706
2.884105    3.043438
-2.967647   2.848696
4.479332    -1.764772
-4.905566   -2.911070
Logo

CSDN联合极客时间,共同打造面向开发者的精品内容学习社区,助力成长!

更多推荐