JAVA实现K-means聚类

2017/04/16 20:29
阅读数 535

重点介绍下K-means聚类算法。K-means算法是比较经典的聚类算法,算法的基本思想是选取K个点(随机)作为中心进行聚类,然后对聚类的结果计算该类的质心,通过迭代的方法不断更新质心,直到质心不变或稍微移动为止,则最后的聚类结果就是最后的聚类结果。下面首先介绍下K-means具体的算法步骤。

 

K-means算法

      在前面已经大概的介绍了下K-means,下面就介绍下具体的算法描述:

1)选取K个点作为初始质心;

2)对每个样本分别计算到K个质心的相似度或距离,将该样本划分到相似度最高或距离最短的质心所在类;

3)对该轮聚类结果,计算每一个类别的质心,新的质心作为下一轮的质心;

4)判断算法是否满足终止条件,满足终止条件结束,否则继续第2、3、4步。

      在介绍算法之前,我们首先看下K-means算法聚类平面200,000个点聚成34个类别的结果(如下图)

img

 

算法实现

      K-means聚类算法整体思想比较简单,下面 就分步介绍如何用Java来实现K-means算法。

 

一、K-means算法基础属性

      在K-means算法中,有几个重要的指标,比如K值、最大迭代次数等,对于这些指标,我们统一把它们设置为类的属性,如下:

 

[java] view plain copy

 print?在CODE上查看代码片派生到我的代码片

  1. private List<T> dataArray;//待分类的原始值  
  2. private int K = 3;//将要分成的类别个数  
  3. private int maxClusterTimes = 500;//最大迭代次数  
  4. private List<List<T>> clusterList;//聚类的结果  
  5. private List<T> clusteringCenterT;//质心  

 

 

二、初始质心的选择

      K-means聚类算法的结果很大程度收到初始质心的选取,这了为了保证有充分的随机性,对于初始质心的选择这里采用完全随机的方法,先把待分类的数据随机打乱,然后把前K个样本作为初始质心(通过多次迭代,会减少初始质心的影响)。

 

[java] view plain copy

 print?在CODE上查看代码片派生到我的代码片

  1. List<T> centerT = new ArrayList<T>(size);  
  2. //对数据进行打乱  
  3. Collections.shuffle(dataArray);  
  4. for (int i = 0; i < size; i++) {  
  5.     centerT.add(dataArray.get(i));  
  6. }  

 

 

三、一轮聚类

      在K-means算法中,大部分的时间都在做一轮一轮的聚类,具体功能也很简单,就是对每一个样本分别计算和所有质心的相似度或距离,找到与该样本最相似的质心或者距离最近的质心,然后把该样本划分到该类中,具体逻辑介绍参照代码中的注释。

[java] view plain copy

 print?在CODE上查看代码片派生到我的代码片

  1. private void clustering(List<T> preCenter, int times) {  
  2.     if (preCenter == null || preCenter.size() < 2) {  
  3.         return;  
  4.     }  
  5.     //打乱质心的顺序  
  6.     Collections.shuffle(preCenter);  
  7.     List<List<T>> clusterList =  getListT(preCenter.size());  
  8.     for (T o1 : this.dataArray) {  
  9.         //寻找最相似的质心  
  10.         int max = 0;  
  11.         double maxScore = similarScore(o1, preCenter.get(0));  
  12.         for (int i = 1; i < preCenter.size(); i++) {  
  13.             if (maxScore < similarScore(o1, preCenter.get(i))) {  
  14.                 maxScore = similarScore(o1, preCenter.get(i));  
  15.                 max = i;  
  16.             }  
  17.         }  
  18.         clusterList.get(max).add(o1);  
  19.     }  
  20.     //计算本次聚类结果每个类别的质心  
  21.     List<T> nowCenter = new ArrayList<T> ();  
  22.     for (List<T> list : clusterList) {  
  23.         nowCenter.add(getCenterT(list));  
  24.     }  
  25.     //是否达到最大迭代次数  
  26.     if (times >= this.maxClusterTimes || preCenter.size() < this.K) {  
  27.         this.clusterList = clusterList;  
  28.         return;  
  29.     }  
  30.     this.clusteringCenterT = nowCenter;  
  31.     //判断质心是否发生移动,如果没有移动,结束本次聚类,否则进行下一轮  
  32.     if (isCenterChange(preCenter, nowCenter)) {  
  33.         clear(clusterList);  
  34.         clustering(nowCenter, times + 1);  
  35.     } else {  
  36.         this.clusterList = clusterList;  
  37.     }  
  38. }  

 

 

四、质心是否移动

      在第三步中,提到了一个重要的步骤:每轮聚类结束后,都要重新计算质心,并且计算质心是否发生移动。对于新质心的计算、样本之间的相似度和判断两个样本是否相等这几个功能由于并不知道样本的具体数据类型,因此把他们定义成抽象方法,供子类来实现。下面就重点介绍如何判断质心是否发生移动。

 

[java] view plain copy

 print?在CODE上查看代码片派生到我的代码片

  1. private boolean isCenterChange(List<T> preT, List<T> nowT) {  
  2.     if (preT == null || nowT == null) {  
  3.         return false;  
  4.     }  
  5.     for (T t1 : preT) {  
  6.         boolean bol = true;  
  7.         for (T t2 : nowT) {  
  8.             if (equals(t1, t2)) {//t1在t2中有相等的,认为该质心未移动  
  9.                 bol = false;  
  10.                 break;  
  11.             }  
  12.         }  
  13.         //有一个质心发生移动,认为需要进行下一次计算  
  14.         if (bol) {  
  15.             return bol;  
  16.         }  
  17.     }  
  18.     return false;  
  19. }  

      从上述代码可以看到,算法的思想就是对于前后两个质心数组分别前一组的质心是否在后一个质心组中出现,有一个没有出现,就认为质心发生了变动。

 

完整代码

      上面四步已经完整的介绍了K-means算法的具体算法思想,下面就看下完整的代码实现。

 

[java] view plain copy

 print?在CODE上查看代码片派生到我的代码片

  1.  /**   
  2.  *@Description:  K-means聚类 
  3.  */   
  4. package com.lulei.datamining.knn;    
  5.   
  6. import java.util.ArrayList;  
  7. import java.util.Collections;  
  8. import java.util.List;  
  9.     
  10. public abstract class KMeansClustering <T>{  
  11.     private List<T> dataArray;//待分类的原始值  
  12.     private int K = 3;//将要分成的类别个数  
  13.     private int maxClusterTimes = 500;//最大迭代次数  
  14.     private List<List<T>> clusterList;//聚类的结果  
  15.     private List<T> clusteringCenterT;//质心  
  16.       
  17.     public int getK() {  
  18.         return K;  
  19.     }  
  20.     public void setK(int K) {  
  21.         if (K < 1) {  
  22.             throw new IllegalArgumentException("K must greater than 0");  
  23.         }  
  24.         this.K = K;  
  25.     }  
  26.     public int getMaxClusterTimes() {  
  27.         return maxClusterTimes;  
  28.     }  
  29.     public void setMaxClusterTimes(int maxClusterTimes) {  
  30.         if (maxClusterTimes < 10) {  
  31.             throw new IllegalArgumentException("maxClusterTimes must greater than 10");  
  32.         }  
  33.         this.maxClusterTimes = maxClusterTimes;  
  34.     }  
  35.     public List<T> getClusteringCenterT() {  
  36.         return clusteringCenterT;  
  37.     }  
  38.     /** 
  39.      * @return 
  40.      * @Author:lulei   
  41.      * @Description: 对数据进行聚类 
  42.      */  
  43.     public List<List<T>> clustering() {  
  44.         if (dataArray == null) {  
  45.             return null;  
  46.         }  
  47.         //初始K个点为数组中的前K个点  
  48.         int size = K > dataArray.size() ? dataArray.size() : K;  
  49.         List<T> centerT = new ArrayList<T>(size);  
  50.         //对数据进行打乱  
  51.         Collections.shuffle(dataArray);  
  52.         for (int i = 0; i < size; i++) {  
  53.             centerT.add(dataArray.get(i));  
  54.         }  
  55.         clustering(centerT, 0);  
  56.         return clusterList;  
  57.     }  
  58.       
  59.     /** 
  60.      * @param preCenter 
  61.      * @param times 
  62.      * @Author:lulei   
  63.      * @Description: 一轮聚类 
  64.      */  
  65.     private void clustering(List<T> preCenter, int times) {  
  66.         if (preCenter == null || preCenter.size() < 2) {  
  67.             return;  
  68.         }  
  69.         //打乱质心的顺序  
  70.         Collections.shuffle(preCenter);  
  71.         List<List<T>> clusterList =  getListT(preCenter.size());  
  72.         for (T o1 : this.dataArray) {  
  73.             //寻找最相似的质心  
  74.             int max = 0;  
  75.             double maxScore = similarScore(o1, preCenter.get(0));  
  76.             for (int i = 1; i < preCenter.size(); i++) {  
  77.                 if (maxScore < similarScore(o1, preCenter.get(i))) {  
  78.                     maxScore = similarScore(o1, preCenter.get(i));  
  79.                     max = i;  
  80.                 }  
  81.             }  
  82.             clusterList.get(max).add(o1);  
  83.         }  
  84.         //计算本次聚类结果每个类别的质心  
  85.         List<T> nowCenter = new ArrayList<T> ();  
  86.         for (List<T> list : clusterList) {  
  87.             nowCenter.add(getCenterT(list));  
  88.         }  
  89.         //是否达到最大迭代次数  
  90.         if (times >= this.maxClusterTimes || preCenter.size() < this.K) {  
  91.             this.clusterList = clusterList;  
  92.             return;  
  93.         }  
  94.         this.clusteringCenterT = nowCenter;  
  95.         //判断质心是否发生移动,如果没有移动,结束本次聚类,否则进行下一轮  
  96.         if (isCenterChange(preCenter, nowCenter)) {  
  97.             clear(clusterList);  
  98.             clustering(nowCenter, times + 1);  
  99.         } else {  
  100.             this.clusterList = clusterList;  
  101.         }  
  102.     }  
  103.       
  104.     /** 
  105.      * @param size 
  106.      * @return 
  107.      * @Author:lulei   
  108.      * @Description: 初始化一个聚类结果 
  109.      */  
  110.     private List<List<T>> getListT(int size) {  
  111.         List<List<T>> list = new ArrayList<List<T>>(size);  
  112.         for (int i = 0; i < size; i++) {  
  113.             list.add(new ArrayList<T>());  
  114.         }  
  115.         return list;  
  116.     }  
  117.       
  118.     /** 
  119.      * @param lists 
  120.      * @Author:lulei   
  121.      * @Description: 清空无用数组 
  122.      */  
  123.     private void clear(List<List<T>> lists) {  
  124.         for (List<T> list : lists) {  
  125.             list.clear();  
  126.         }  
  127.         lists.clear();  
  128.     }  
  129.       
  130.     /** 
  131.      * @param value 
  132.      * @Author:lulei   
  133.      * @Description: 向模型中添加记录 
  134.      */  
  135.     public void addRecord(T value) {  
  136.         if (dataArray == null) {  
  137.             dataArray = new ArrayList<T>();  
  138.         }  
  139.         dataArray.add(value);  
  140.     }  
  141.       
  142.     /** 
  143.      * @param preT 
  144.      * @param nowT 
  145.      * @return 
  146.      * @Author:lulei   
  147.      * @Description: 判断质心是否发生移动 
  148.      */  
  149.     private boolean isCenterChange(List<T> preT, List<T> nowT) {  
  150.         if (preT == null || nowT == null) {  
  151.             return false;  
  152.         }  
  153.         for (T t1 : preT) {  
  154.             boolean bol = true;  
  155.             for (T t2 : nowT) {  
  156.                 if (equals(t1, t2)) {//t1在t2中有相等的,认为该质心未移动  
  157.                     bol = false;  
  158.                     break;  
  159.                 }  
  160.             }  
  161.             //有一个质心发生移动,认为需要进行下一次计算  
  162.             if (bol) {  
  163.                 return bol;  
  164.             }  
  165.         }  
  166.         return false;  
  167.     }  
  168.       
  169.     /** 
  170.      * @param o1 
  171.      * @param o2 
  172.      * @return 
  173.      * @Author:lulei   
  174.      * @Description: o1 o2之间的相似度 
  175.      */  
  176.     public abstract double similarScore(T o1, T o2);  
  177.       
  178.     /** 
  179.      * @param o1 
  180.      * @param o2 
  181.      * @return 
  182.      * @Author:lulei   
  183.      * @Description: 判断o1 o2是否相等 
  184.      */  
  185.     public abstract boolean equals(T o1, T o2);  
  186.       
  187.     /** 
  188.      * @param list 
  189.      * @return 
  190.      * @Author:lulei   
  191.      * @Description: 求一组数据的质心 
  192.      */  
  193.     public abstract T getCenterT(List<T> list);  
  194. }  

 

二维数聚类实现

      在算法描述中,介绍了一个200,000个点聚成34个类别的效果图,下面就针对二维坐标数据实现其具体子类。

 

一、相似度

      对于二维坐标的相似度,这里我们采取两点间聚类的相反数,具体实现如下:

 

[java] view plain copy

 print?在CODE上查看代码片派生到我的代码片

  1. @Override  
  2. public double similarScore(XYbean o1, XYbean o2) {  
  3.     double distance = Math.sqrt((o1.getX() - o2.getX()) * (o1.getX() - o2.getX()) + (o1.getY() - o2.getY()) * (o1.getY() - o2.getY()));  
  4.     return distance * -1;  
  5. }  

 

 

二、样本/质心是否相等

      判断样本/质心是否相等只需要判断两点的坐标是否相等即可,具体实现如下:

 

[java] view plain copy

 print?在CODE上查看代码片派生到我的代码片

  1. @Override  
  2. public boolean equals(XYbean o1, XYbean o2) {  
  3.     return o1.getX() == o2.getX() && o1.getY() == o2.getY();  
  4. }  

 

 

三、获取一个分类下的新质心

      对于二维坐标数据,可以使用所有点的重心作为分类的质心,具体如下:

 

[java] view plain copy

 print?在CODE上查看代码片派生到我的代码片

  1. @Override  
  2. public XYbean getCenterT(List<XYbean> list) {  
  3.     int x = 0;  
  4.     int y = 0;  
  5.     try {  
  6.         for (XYbean xy : list) {  
  7.             x += xy.getX();  
  8.             y += xy.getY();  
  9.         }  
  10.         x = x / list.size();  
  11.         y = y / list.size();  
  12.     } catch(Exception e) {  
  13.           
  14.     }  
  15.     return new XYbean(x, y);  
  16. }  

 

 

四、main方法

      对于具体二维坐标的源码这里就不再贴出来,就是实现前面介绍的抽象类,并实现其中的3个抽象方法,下面我们就随机产生200,000个点,然后聚成34个类别,具体代码如下:

 

[java] view plain copy

 print?在CODE上查看代码片派生到我的代码片

  1. public static void main(String[] args) {  
  2.       
  3.     int width = 600;  
  4.     int height = 400;  
  5.     int K = 34;  
  6.     XYCluster xyCluster = new XYCluster();  
  7.     for (int i = 0; i < 200000; i++) {  
  8.         int x = (int)(Math.random() * width) + 1;  
  9.         int y = (int)(Math.random() * height) + 1;  
  10.         xyCluster.addRecord(new XYbean(x, y));  
  11.     }  
  12.     xyCluster.setK(K);  
  13.     long a = System.currentTimeMillis();  
  14.     List<List<XYbean>> cresult = xyCluster.clustering();  
  15.     List<XYbean> center = xyCluster.getClusteringCenterT();  
  16.     System.out.println(JsonUtil.parseJson(center));  
  17.     long b = System.currentTimeMillis();  
  18.     System.out.println("耗时:" + (b - a) + "ms");  
  19.     new ImgUtil().drawXYbeans(width, height, cresult, "d:/2.png", 0, 0);  
  20. }  

 

 

      对于这随机产生的200,000个点聚成34类,总耗时5485ms。(计算机配置:i5 + 8G内存)

展开阅读全文
打赏
0
0 收藏
分享
加载中
更多评论
打赏
0 评论
0 收藏
0
分享
返回顶部
顶部