kmeans之于模式识别,如同“hello world”之于C、之于任何一门高级语言。

算法的规格(specification)

在聚类问题(一般非监督问题)中,给定训练样本X={x(1),x(2),,x(N)},每个x(i)Rd。kmeans算法的职责在于将这N个样本聚类成k个簇(cluster, μ1,μ2,,μk),流程如下:

  1. 随机选取k个聚类中心(cluster centroids)为μ1,μ2,,μk 
    C = X(randperm(m*n, k), :); # 程序语言

  2. 重复一下过程直至收敛 

    对于每一个样本i,根据最近邻(欧氏距离度量)计算其所属分类 

    c(i):=argminjx(i)μj2

    对于每一个类j,重新计算该类的质心(centroids) 
    μj:=mi=11{c(i)=j}x(i)mi=11{c(i)=j}

    }

算法的规格:

  • 一个参数k,聚类中心的数目,当然也有一些常规的参数,比如最大迭代次数epochs,容忍度tol
  • 一个循环,判断目标函数是否变化足够小,以F范数(Frobenius norm)为度归。
while true,
    ...
    if norm(J_cur-J_prev, 'fro') < tol,
        break;
    end
    J_prev = J_cur;
end
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 一条更新语句,更新各个类的聚类中心,根据每个样本应属的类别(欧式距离最小表征)

μj:=mi=11{c(i)=j}x(i)mi=11{c(i)=j}

这个公式看似高大上,实则不值一提,翻译过来就是新的聚类中心(centroid)在该类别空间的中心处。

    dist = sum(X.^2, 2)*ones(1, k) + (sum(C.^2, 2)*ones(1, m*n))'...
        - 2*X*C';
    [~, idx] = min(dist, [], 2) ;
    for i = 1:k,
       C(i, :) = mean(X(idx == i , :)); # 对应于这样一条语句
    end
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

matlab实现

客户端(client)程序

clear all; close all;
I = imread('./lena.bmp');
[m, n, p] = size(I);
k = 7;
[C, label, J] = kmeans(I, k);
I_seg = reshape(C(label, :), m, n, p);
figure
subplot(1, 2, 1), imshow(I, []), title('原图')
subplot(1, 2, 2), imshow(uint8(I_seg), []), title('聚类图')
figure
plot(1:length(J), J), xlabel('#iterations')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

kmeans函数

function [C, label, J] = kmeans(I, k)
[m, n, p] = size(I);
X = reshape(double(I), m*n, p);
rng('default');
C = X(randperm(m*n, k), :);
J_prev = inf; iter = 0; J = []; tol = 1e-2;
while true,
    iter = iter + 1;
    dist = sum(X.^2, 2)*ones(1, k) + (sum(C.^2, 2)*ones(1, m*n))' - 2*X*C';
    [~, label] = min(dist, [], 2) ;
    for i = 1:k,
       C(i, :) = mean(X(label == i , :));
    end
    J_cur = sum(sum((X - C(label, :)).^2, 2));
    J = [J, J_cur];
    display(sprintf('#iteration: %03d, objective fcn: %f', iter, J_cur));
    if norm(J_cur-J_prev, 'fro') < tol,
        break;
    end
    J_prev = J_cur;
end
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22

实验结果

目标函数收敛情况

目标函数 

J(c,μ)=i=1mx(i)μc(i)2

matlab计算程序:

J_cur = sum(sum((X - C(label, :)).^2, 2));
  • 1
  • 1

这里写图片描述

效果图

这里写图片描述

Logo

CSDN联合极客时间,共同打造面向开发者的精品内容学习社区,助力成长!

更多推荐