文档章节

JAVA实现KNN分类

一贱书生
 一贱书生
发布于 2017/04/16 20:26
字数 3936
阅读 47
收藏 0

 KNN算法又叫近邻算法,是数据挖掘中一种常用的分类算法,接单的介绍KNN算法的核心思想就是:寻找与目标最近的K个个体,这些样本属于类别最多的那个类别就是目标的类别。比如K为7,那么我们就从数据中找到和目标最近(或者相似度最高)的7个样本,加入这7个样本对应的类别分别为A、B、C、A、A、A、B,那么目标属于的分类就是A(因为这7个样本中属于A类别的样本个数最多)。

 

算法实现

一、训练数据格式定义

      下面就简单的介绍下如何用Java来实现KNN分类,首先我们需要存储训练集(包括属性以及对应的类别),这里我们对未知的属性使用泛型,类别我们使用字符串存储。

 

[java] view plain copy

 print?在CODE上查看代码片派生到我的代码片

  1.  /**   
  2.  *@Description:  KNN分类模型中一条记录的存储格式 
  3.  */   
  4. package com.lulei.datamining.knn.bean;    
  5.     
  6. public class KnnValueBean<T>{  
  7.     private T value;//记录值  
  8.     private String typeId;//分类ID  
  9.       
  10.     public KnnValueBean(T value, String typeId) {  
  11.         this.value = value;  
  12.         this.typeId = typeId;  
  13.     }  
  14.   
  15.     public T getValue() {  
  16.         return value;  
  17.     }  
  18.   
  19.     public void setValue(T value) {  
  20.         this.value = value;  
  21.     }  
  22.   
  23.     public String getTypeId() {  
  24.         return typeId;  
  25.     }  
  26.   
  27.     public void setTypeId(String typeId) {  
  28.         this.typeId = typeId;  
  29.     }  
  30. }  


二、K个最近邻类别数据格式定义

 

      在统计得到K个最近邻中,我们需要记录前K个样本的分类以及对应的相似度,我们这里使用如下数据格式:

 

[java] view plain copy

 print?在CODE上查看代码片派生到我的代码片

  1.  /**   
  2.  *@Description: K个最近邻的类别得分 
  3.  */   
  4. package com.lulei.datamining.knn.bean;    
  5.     
  6. public class KnnValueSort {  
  7.     private String typeId;//分类ID  
  8.     private double score;//该分类得分  
  9.       
  10.     public KnnValueSort(String typeId, double score) {  
  11.         this.typeId = typeId;  
  12.         this.score = score;  
  13.     }  
  14.     public String getTypeId() {  
  15.         return typeId;  
  16.     }  
  17.     public void setTypeId(String typeId) {  
  18.         this.typeId = typeId;  
  19.     }  
  20.     public double getScore() {  
  21.         return score;  
  22.     }  
  23.     public void setScore(double score) {  
  24.         this.score = score;  
  25.     }  
  26. }  


三、KNN算法基本属性

 

      在KNN算法中,最重要的一个指标就是K的取值,因此我们在基类中需要设置一个属性K以及设置一个数组用于存储已知分类的数据。

 

[java] view plain copy

 print?在CODE上查看代码片派生到我的代码片

  1. private List<KnnValueBean> dataArray;  
  2. private int K = 3;  


四、添加已知分类数据

 

      在使用KNN分类之前,我们需要先向其中添加我们已知分类的数据,我们后面就是使用这些数据来预测未知数据的分类。

 

[java] view plain copy

 print?在CODE上查看代码片派生到我的代码片

  1. /** 
  2.  * @param value 
  3.  * @param typeId 
  4.  * @Author:lulei   
  5.  * @Description: 向模型中添加记录 
  6.  */  
  7. public void addRecord(T value, String typeId) {  
  8.     if (dataArray == null) {  
  9.         dataArray = new ArrayList<KnnValueBean>();  
  10.     }  
  11.     dataArray.add(new KnnValueBean<T>(value, typeId));  
  12. }  


五、两个样本之间的相似度(或者距离)

 

      在KNN算法中,最重要的一个方法就是如何确定两个样本之间的相似度(或者距离),由于这里我们使用的是泛型,并没有办法确定两个对象之间的相似度,一次这里我们把它设置为抽象方法,让子类来实现。这里我们方法定义为相似度,也就是返回值越大,两者越相似,之间的距离越短

 

[java] view plain copy

 print?在CODE上查看代码片派生到我的代码片

  1. /** 
  2.  * @param o1 
  3.  * @param o2 
  4.  * @return 
  5.  * @Author:lulei   
  6.  * @Description: o1 o2之间的相似度 
  7.  */  
  8. public abstract double similarScore(T o1, T o2);  


六、获取最近的K个样本的分类

 

      KNN算法的核心思想就是找到最近的K个近邻,因此这一步也是整个算法的核心部分。这里我们使用数组来保存相似度最大的K个样本的分类和相似度,在计算的过程中通过循环遍历所有的样本,数组保存截至当前计算点最相似的K个样本对应的类别和相似度,具体实现如下:

 

[java] view plain copy

 print?在CODE上查看代码片派生到我的代码片

  1. /** 
  2.  * @param value 
  3.  * @return 
  4.  * @Author:lulei   
  5.  * @Description: 获取距离最近的K个分类 
  6.  */  
  7. private KnnValueSort[] getKType(T value) {  
  8.     int k = 0;  
  9.     KnnValueSort[] topK = new KnnValueSort[K];  
  10.     for (KnnValueBean<T> bean : dataArray) {  
  11.         double score = similarScore(bean.getValue(), value);  
  12.         if (k == 0) {  
  13.             //数组中的记录个数为0是直接添加  
  14.             topK[k] = new KnnValueSort(bean.getTypeId(), score);  
  15.             k++;  
  16.         } else {  
  17.             if (!(k == K && score < topK[k -1].getScore())){  
  18.                 int i = 0;  
  19.                 //找到要插入的点  
  20.                 for (; i < k && score < topK[i].getScore(); i++);  
  21.                 int j = k - 1;  
  22.                 if (k < K) {  
  23.                     j = k;  
  24.                     k++;  
  25.                 }  
  26.                 for (; j > i; j--) {  
  27.                     topK[j] = topK[j - 1];  
  28.                 }  
  29.                 topK[i] = new KnnValueSort(bean.getTypeId(), score);  
  30.             }  
  31.         }  
  32.     }  
  33.     return topK;  
  34. }  


七、统计K个样本出现次数最多的类别

 

      这一步就是一个简单的计数,统计K个样本中出现次数最多的分类,该分类就是我们要预测的目标数据的分类。

 

[java] view plain copy

 print?在CODE上查看代码片派生到我的代码片

  1. /** 
  2.  * @param value 
  3.  * @return 
  4.  * @Author:lulei   
  5.  * @Description: KNN分类判断value的类别 
  6.  */  
  7. public String getTypeId(T value) {  
  8.     KnnValueSort[] array = getKType(value);  
  9.     HashMap<String, Integer> map = new HashMap<String, Integer>(K);  
  10.     for (KnnValueSort bean : array) {  
  11.         if (bean != null) {  
  12.             if (map.containsKey(bean.getTypeId())) {  
  13.                 map.put(bean.getTypeId(), map.get(bean.getTypeId()) + 1);  
  14.             } else {  
  15.                 map.put(bean.getTypeId(), 1);  
  16.             }  
  17.         }  
  18.     }  
  19.     String maxTypeId = null;  
  20.     int maxCount = 0;  
  21.     Iterator<Entry<String, Integer>> iter = map.entrySet().iterator();  
  22.     while (iter.hasNext()) {  
  23.         Entry<String, Integer> entry = iter.next();  
  24.         if (maxCount < entry.getValue()) {  
  25.             maxCount = entry.getValue();  
  26.             maxTypeId = entry.getKey();  
  27.         }  
  28.     }  
  29.     return maxTypeId;  
  30. }  


      到现在为止KNN分类的抽象基类已经编写完成,在测试之前我们先多说几句,KNN分类是统计K个样本中出现次数最多的分类,这种在有些情况下并不是特别合理,比如K=5,前5个样本对应的分类分别为A、A、B、B、B,对应的相似度得分分别为10、9、2、2、1,如果使用上面的方法,那预测的分类就是B,但是看这些数据,预测的分类是A感觉更合理。基于这种情况,自己对KNN算法提出如下优化(这里并不提供代码,只提供简单的思路):在获取最相似K个样本和相似度后,可以对相似度和出现次数K做一种函数运算,比如加权,得到的函数值最大的分类就是目标的预测分类。

 

基类源码

 

[java] view plain copy

 print?在CODE上查看代码片派生到我的代码片

  1.  /**   
  2.  *@Description: KNN分类 
  3.  */   
  4. package com.lulei.datamining.knn;    
  5.   
  6. import java.util.ArrayList;  
  7. import java.util.HashMap;  
  8. import java.util.Iterator;  
  9. import java.util.List;  
  10. import java.util.Map.Entry;  
  11.   
  12. import com.lulei.datamining.knn.bean.KnnValueBean;  
  13. import com.lulei.datamining.knn.bean.KnnValueSort;  
  14. import com.lulei.util.JsonUtil;  
  15.     
  16. @SuppressWarnings({"rawtypes"})  
  17. public abstract class KnnClassification<T> {  
  18.     private List<KnnValueBean> dataArray;  
  19.     private int K = 3;  
  20.       
  21.     public int getK() {  
  22.         return K;  
  23.     }  
  24.     public void setK(int K) {  
  25.         if (K < 1) {  
  26.             throw new IllegalArgumentException("K must greater than 0");  
  27.         }  
  28.         this.K = K;  
  29.     }  
  30.   
  31.     /** 
  32.      * @param value 
  33.      * @param typeId 
  34.      * @Author:lulei   
  35.      * @Description: 向模型中添加记录 
  36.      */  
  37.     public void addRecord(T value, String typeId) {  
  38.         if (dataArray == null) {  
  39.             dataArray = new ArrayList<KnnValueBean>();  
  40.         }  
  41.         dataArray.add(new KnnValueBean<T>(value, typeId));  
  42.     }  
  43.       
  44.     /** 
  45.      * @param value 
  46.      * @return 
  47.      * @Author:lulei   
  48.      * @Description: KNN分类判断value的类别 
  49.      */  
  50.     public String getTypeId(T value) {  
  51.         KnnValueSort[] array = getKType(value);  
  52.         System.out.println(JsonUtil.parseJson(array));  
  53.         HashMap<String, Integer> map = new HashMap<String, Integer>(K);  
  54.         for (KnnValueSort bean : array) {  
  55.             if (bean != null) {  
  56.                 if (map.containsKey(bean.getTypeId())) {  
  57.                     map.put(bean.getTypeId(), map.get(bean.getTypeId()) + 1);  
  58.                 } else {  
  59.                     map.put(bean.getTypeId(), 1);  
  60.                 }  
  61.             }  
  62.         }  
  63.         String maxTypeId = null;  
  64.         int maxCount = 0;  
  65.         Iterator<Entry<String, Integer>> iter = map.entrySet().iterator();  
  66.         while (iter.hasNext()) {  
  67.             Entry<String, Integer> entry = iter.next();  
  68.             if (maxCount < entry.getValue()) {  
  69.                 maxCount = entry.getValue();  
  70.                 maxTypeId = entry.getKey();  
  71.             }  
  72.         }  
  73.         return maxTypeId;  
  74.     }  
  75.       
  76.     /** 
  77.      * @param value 
  78.      * @return 
  79.      * @Author:lulei   
  80.      * @Description: 获取距离最近的K个分类 
  81.      */  
  82.     private KnnValueSort[] getKType(T value) {  
  83.         int k = 0;  
  84.         KnnValueSort[] topK = new KnnValueSort[K];  
  85.         for (KnnValueBean<T> bean : dataArray) {  
  86.             double score = similarScore(bean.getValue(), value);  
  87.             if (k == 0) {  
  88.                 //数组中的记录个数为0是直接添加  
  89.                 topK[k] = new KnnValueSort(bean.getTypeId(), score);  
  90.                 k++;  
  91.             } else {  
  92.                 if (!(k == K && score < topK[k -1].getScore())){  
  93.                     int i = 0;  
  94.                     //找到要插入的点  
  95.                     for (; i < k && score < topK[i].getScore(); i++);  
  96.                     int j = k - 1;  
  97.                     if (k < K) {  
  98.                         j = k;  
  99.                         k++;  
  100.                     }  
  101.                     for (; j > i; j--) {  
  102.                         topK[j] = topK[j - 1];  
  103.                     }  
  104.                     topK[i] = new KnnValueSort(bean.getTypeId(), score);  
  105.                 }  
  106.             }  
  107.         }  
  108.         return topK;  
  109.     }  
  110.       
  111.     /** 
  112.      * @param o1 
  113.      * @param o2 
  114.      * @return 
  115.      * @Author:lulei   
  116.      * @Description: o1 o2之间的相似度 
  117.      */  
  118.     public abstract double similarScore(T o1, T o2);  
  119. }  

 

 

具体子类实现

      对于上面介绍的都在KNN分类的抽象基类中,对于实际的问题我们需要继承基类并实现基类中的相似度抽象方法,这里我们做一个简单的实现。

 

[java] view plain copy

 print?在CODE上查看代码片派生到我的代码片

  1.  /**   
  2.  *@Description:      
  3.  */   
  4. package com.lulei.datamining.knn.test;    
  5.   
  6. import com.lulei.datamining.knn.KnnClassification;  
  7. import com.lulei.util.JsonUtil;  
  8.     
  9. public class Test extends KnnClassification<Integer>{  
  10.       
  11.     @Override  
  12.     public double similarScore(Integer o1, Integer o2) {  
  13.         return -1 * Math.abs(o1 - o2);  
  14.     }  
  15.       
  16.     /**   
  17.      * @param args 
  18.      * @Author:lulei   
  19.      * @Description:   
  20.      */  
  21.     public static void main(String[] args) {  
  22.         Test test = new Test();  
  23.         for (int i = 1; i < 10; i++) {  
  24.             test.addRecord(i, i > 5 ? "0" : "1");  
  25.         }  
  26.         System.out.println(JsonUtil.parseJson(test.getTypeId(0)));  
  27.           
  28.     }  
  29. }  

 

 

      这里我们一共添加了1、2、3、4、5、6、7、8、9这9组数据,前5组的类别为1,后4组的类别为0,两个数据之间的相似度为两者之间的差值的绝对值的相反数,下面预测0应该属于的分类,这里K的默认值为3,因此最近的K个样本分别为1、2、3,对应的分类分别为"1"、"1"、"1",因为最后预测的分类为"1"。

 

KNN算法全名为k-Nearest Neighbor,就是K最近邻的意思。KNN也是一种分类算法。但是与之前说的决策树分类算法相比,这个算法算是最简单的一个了。算法的主要过程为:

1、给定一个训练集数据,每个训练集数据都是已经分好类的。
2、设定一个初始的测试数据a,计算a到训练集所有数据的欧几里得距离,并排序。                       

3、选出训练集中离a距离最近的K个训练集数据。

4、比较k个训练集数据,选出里面出现最多的分类类型,此分类类型即为最终测试数据a的分类。

下面百度百科上的一张简图:

KNN算法实现

首先测试数据需要2块,1个是训练集数据,就是已经分好类的数据,比如上图中的非绿色的点。还有一个是测试数据,就是上面的绿点,当然这里的测试数据不会是一个,而是一组。这里的数据与数据之间的距离用数据的特征向量做计算,特征向量可以是多维度的。通过计算特征向量与特征向量之间的欧几里得距离来推算相似度。定义训练集数据trainInput.txt:

 

[java] view plain copy

 print?

  1. a 1 2 3 4 5   
  2. b 5 4 3 2 1   
  3. c 3 3 3 3 3   
  4. d -3 -3 -3 -3 -3   
  5. a 1 2 3 4 4   
  6. b 4 4 3 2 1   
  7. c 3 3 3 2 4   
  8. d 0 0 1 1 -2   

待测试数据testInput,只有特征向量值:

 

 

[java] view plain copy

 print?

  1. 1 2 3 2 4   
  2. 2 3 4 2 1   
  3. 8 7 2 3 5   
  4. -3 -2 2 4 0   
  5. -4 -4 -4 -4 -4   
  6. 1 2 3 4 4   
  7. 4 4 3 2 1   
  8. 3 3 3 2 4   
  9. 0 0 1 1 -2   

下面是主程序:

 

 

[java] view plain copy

 print?

  1. package DataMing_KNN;  
  2.   
  3. import java.io.BufferedReader;  
  4. import java.io.File;  
  5. import java.io.FileReader;  
  6. import java.io.IOException;  
  7. import java.util.ArrayList;  
  8. import java.util.Arrays;  
  9. import java.util.Collection;  
  10. import java.util.Collections;  
  11. import java.util.Comparator;  
  12. import java.util.HashMap;  
  13. import java.util.Map;  
  14.   
  15. import org.apache.activemq.filter.ComparisonExpression;  
  16.   
  17. /** 
  18.  * k最近邻算法工具类 
  19.  *  
  20.  * @author lyq 
  21.  *  
  22.  */  
  23. public class KNNTool {  
  24.     // 为4个类别设置权重,默认权重比一致  
  25.     public int[] classWeightArray = new int[] { 1, 1, 1, 1 };  
  26.     // 测试数据地址  
  27.     private String testDataPath;  
  28.     // 训练集数据地址  
  29.     private String trainDataPath;  
  30.     // 分类的不同类型  
  31.     private ArrayList<String> classTypes;  
  32.     // 结果数据  
  33.     private ArrayList<Sample> resultSamples;  
  34.     // 训练集数据列表容器  
  35.     private ArrayList<Sample> trainSamples;  
  36.     // 训练集数据  
  37.     private String[][] trainData;  
  38.     // 测试集数据  
  39.     private String[][] testData;  
  40.   
  41.     public KNNTool(String trainDataPath, String testDataPath) {  
  42.         this.trainDataPath = trainDataPath;  
  43.         this.testDataPath = testDataPath;  
  44.         readDataFormFile();  
  45.     }  
  46.   
  47.     /** 
  48.      * 从文件中阅读测试数和训练数据集 
  49.      */  
  50.     private void readDataFormFile() {  
  51.         ArrayList<String[]> tempArray;  
  52.   
  53.         tempArray = fileDataToArray(trainDataPath);  
  54.         trainData = new String[tempArray.size()][];  
  55.         tempArray.toArray(trainData);  
  56.   
  57.         classTypes = new ArrayList<>();  
  58.         for (String[] s : tempArray) {  
  59.             if (!classTypes.contains(s[0])) {  
  60.                 // 添加类型  
  61.                 classTypes.add(s[0]);  
  62.             }  
  63.         }  
  64.   
  65.         tempArray = fileDataToArray(testDataPath);  
  66.         testData = new String[tempArray.size()][];  
  67.         tempArray.toArray(testData);  
  68.     }  
  69.   
  70.     /** 
  71.      * 将文件转为列表数据输出 
  72.      *  
  73.      * @param filePath 
  74.      *            数据文件的内容 
  75.      */  
  76.     private ArrayList<String[]> fileDataToArray(String filePath) {  
  77.         File file = new File(filePath);  
  78.         ArrayList<String[]> dataArray = new ArrayList<String[]>();  
  79.   
  80.         try {  
  81.             BufferedReader in = new BufferedReader(new FileReader(file));  
  82.             String str;  
  83.             String[] tempArray;  
  84.             while ((str = in.readLine()) != null) {  
  85.                 tempArray = str.split(" ");  
  86.                 dataArray.add(tempArray);  
  87.             }  
  88.             in.close();  
  89.         } catch (IOException e) {  
  90.             e.getStackTrace();  
  91.         }  
  92.   
  93.         return dataArray;  
  94.     }  
  95.   
  96.     /** 
  97.      * 计算样本特征向量的欧几里得距离 
  98.      *  
  99.      * @param f1 
  100.      *            待比较样本1 
  101.      * @param f2 
  102.      *            待比较样本2 
  103.      * @return 
  104.      */  
  105.     private int computeEuclideanDistance(Sample s1, Sample s2) {  
  106.         String[] f1 = s1.getFeatures();  
  107.         String[] f2 = s2.getFeatures();  
  108.         // 欧几里得距离  
  109.         int distance = 0;  
  110.   
  111.         for (int i = 0; i < f1.length; i++) {  
  112.             int subF1 = Integer.parseInt(f1[i]);  
  113.             int subF2 = Integer.parseInt(f2[i]);  
  114.   
  115.             distance += (subF1 - subF2) * (subF1 - subF2);  
  116.         }  
  117.   
  118.         return distance;  
  119.     }  
  120.   
  121.     /** 
  122.      * 计算K最近邻 
  123.      * @param k 
  124.      * 在多少的k范围内 
  125.      */  
  126.     public void knnCompute(int k) {  
  127.         String className = "";  
  128.         String[] tempF = null;  
  129.         Sample temp;  
  130.         resultSamples = new ArrayList<>();  
  131.         trainSamples = new ArrayList<>();  
  132.         // 分类类别计数  
  133.         HashMap<String, Integer> classCount;  
  134.         // 类别权重比  
  135.         HashMap<String, Integer> classWeight = new HashMap<>();  
  136.         // 首先讲测试数据转化到结果数据中  
  137.         for (String[] s : testData) {  
  138.             temp = new Sample(s);  
  139.             resultSamples.add(temp);  
  140.         }  
  141.   
  142.         for (String[] s : trainData) {  
  143.             className = s[0];  
  144.             tempF = new String[s.length - 1];  
  145.             System.arraycopy(s, 1, tempF, 0, s.length - 1);  
  146.             temp = new Sample(className, tempF);  
  147.             trainSamples.add(temp);  
  148.         }  
  149.   
  150.         // 离样本最近排序的的训练集数据  
  151.         ArrayList<Sample> kNNSample = new ArrayList<>();  
  152.         // 计算训练数据集中离样本数据最近的K个训练集数据  
  153.         for (Sample s : resultSamples) {  
  154.             classCount = new HashMap<>();  
  155.             int index = 0;  
  156.             for (String type : classTypes) {  
  157.                 // 开始时计数为0  
  158.                 classCount.put(type, 0);  
  159.                 classWeight.put(type, classWeightArray[index++]);  
  160.             }  
  161.             for (Sample tS : trainSamples) {  
  162.                 int dis = computeEuclideanDistance(s, tS);  
  163.                 tS.setDistance(dis);  
  164.             }  
  165.   
  166.             Collections.sort(trainSamples);  
  167.             kNNSample.clear();  
  168.             // 挑选出前k个数据作为分类标准  
  169.             for (int i = 0; i < trainSamples.size(); i++) {  
  170.                 if (i < k) {  
  171.                     kNNSample.add(trainSamples.get(i));  
  172.                 } else {  
  173.                     break;  
  174.                 }  
  175.             }  
  176.             // 判定K个训练数据的多数的分类标准  
  177.             for (Sample s1 : kNNSample) {  
  178.                 int num = classCount.get(s1.getClassName());  
  179.                 // 进行分类权重的叠加,默认类别权重平等,可自行改变,近的权重大,远的权重小  
  180.                 num += classWeight.get(s1.getClassName());  
  181.                 classCount.put(s1.getClassName(), num);  
  182.             }  
  183.   
  184.             int maxCount = 0;  
  185.             // 筛选出k个训练集数据中最多的一个分类  
  186.             for (Map.Entry entry : classCount.entrySet()) {  
  187.                 if ((Integer) entry.getValue() > maxCount) {  
  188.                     maxCount = (Integer) entry.getValue();  
  189.                     s.setClassName((String) entry.getKey());  
  190.                 }  
  191.             }  
  192.   
  193.             System.out.print("测试数据特征:");  
  194.             for (String s1 : s.getFeatures()) {  
  195.                 System.out.print(s1 + " ");  
  196.             }  
  197.             System.out.println("分类:" + s.getClassName());  
  198.         }  
  199.     }  
  200. }  

Sample样本数据类:

 

 

[java] view plain copy

 print?

  1. package DataMing_KNN;  
  2.   
  3. /** 
  4.  * 样本数据类 
  5.  *  
  6.  * @author lyq 
  7.  *  
  8.  */  
  9. public class Sample implements Comparable<Sample>{  
  10.     // 样本数据的分类名称  
  11.     private String className;  
  12.     // 样本数据的特征向量  
  13.     private String[] features;  
  14.     //测试样本之间的间距值,以此做排序  
  15.     private Integer distance;  
  16.       
  17.     public Sample(String[] features){  
  18.         this.features = features;  
  19.     }  
  20.       
  21.     public Sample(String className, String[] features){  
  22.         this.className = className;  
  23.         this.features = features;  
  24.     }  
  25.   
  26.     public String getClassName() {  
  27.         return className;  
  28.     }  
  29.   
  30.     public void setClassName(String className) {  
  31.         this.className = className;  
  32.     }  
  33.   
  34.     public String[] getFeatures() {  
  35.         return features;  
  36.     }  
  37.   
  38.     public void setFeatures(String[] features) {  
  39.         this.features = features;  
  40.     }  
  41.   
  42.     public Integer getDistance() {  
  43.         return distance;  
  44.     }  
  45.   
  46.     public void setDistance(int distance) {  
  47.         this.distance = distance;  
  48.     }  
  49.   
  50.     @Override  
  51.     public int compareTo(Sample o) {  
  52.         // TODO Auto-generated method stub  
  53.         return this.getDistance().compareTo(o.getDistance());  
  54.     }  
  55.       
  56. }  

测试场景类:

 

 

[java] view plain copy

 print?

  1. /** 
  2.  * k最近邻算法场景类型 
  3.  * @author lyq 
  4.  * 
  5.  */  
  6. public class Client {  
  7.     public static void main(String[] args){  
  8.         String trainDataPath = "C:\\Users\\lyq\\Desktop\\icon\\trainInput.txt";  
  9.         String testDataPath = "C:\\Users\\lyq\\Desktop\\icon\\testinput.txt";  
  10.           
  11.         KNNTool tool = new KNNTool(trainDataPath, testDataPath);  
  12.         tool.knnCompute(3);  
  13.           
  14.     }  
  15.       
  16.   
  17.   
  18. }  

执行的结果为:

 

 

[java] view plain copy

 print?

  1. 测试数据特征:1 2 3 2 4 分类:a  
  2. 测试数据特征:2 3 4 2 1 分类:c  
  3. 测试数据特征:8 7 2 3 5 分类:b  
  4. 测试数据特征:-3 -2 2 4 0 分类:a  
  5. 测试数据特征:-4 -4 -4 -4 -4 分类:d  
  6. 测试数据特征:1 2 3 4 4 分类:a  
  7. 测试数据特征:4 4 3 2 1 分类:b  
  8. 测试数据特征:3 3 3 2 4 分类:c  
  9. 测试数据特征:0 0 1 1 -2 分类:d  

 

程序的输出结果如上所示,如果不相信的话可以自己动手计算进行验证。

KNN算法的注意点:

1、knn算法的训练集数据必须要相对公平,各个类型的数据数量应该是平均的,否则当A数据由1000个B数据由100个,到时无论如何A数据的样本还是占优的。

2、knn算法如果纯粹凭借分类的多少做判断,还是可以继续优化的,比如近的数据的权重可以设大,最后根据所有的类型权重和进行比较,而不是单纯的凭借数量。

3、knn算法的缺点是计算量大,这个从程序中也应该看得出来,里面每个测试数据都要计算到所有的训练集数据之间的欧式距离,时间复杂度就已经为O(n*n),如果真实数据的n非常大,这个算法的开销的确态度,所以KNN不适合大规模数据量的分类。

KNN算法编码时遇到的困难:

按理来说这么简单的KNN算法本应该是没有多少的难度,但是在多欧式距离的排序上被深深的坑了一段时间,本人起初用Collections.sort(list)的方式进行按距离排序,也把Sample类实现了Compareable接口,但是排序就是不变,最后才知道,distance的int类型要改为Integer引用类型,在compareTo重载方法中调用distance的.CompareTo()方法就成功了,这个小细节平时没注意,难道属性的比较最终一定要调用到引用类型的compareTo()方法?这个小问题竟然花费了我一段时间,最后仔细的比较了一下网上的例子最后才发现......

本文转载自:http://blog.csdn.net/xiaojimanman/article/details/51086879

下一篇: 动态规划
一贱书生
粉丝 20
博文 724
码字总数 600123
作品 0
私信 提问
Java 文本分类器集合 - text-classifier-collection

文本分类器集合 一个强大易用的Java文本分类工具包 特色 功能全面 内置信息检索中各种常用的文本预处理方法,如语言感知分词、词干提取、繁简转换、停用词去除、同义词插入、n-gram生成等等 ...

chanchungkwong
2018/05/21
0
0
Apache开源项目分类列表

分类 项目名 说明 开发语言 服务器 (共20) Apache HTTP Server 全球第一HTTP服务器 C/C++ Tomcat Java的Web服务器 Java James 邮件服务器 Java SpamAssassin 反垃圾邮件 C/C++ Perl Apach...

johnnyhg
2009/05/08
2.3K
0
25 个 Java 机器学习工具和库

本列表总结了25个Java机器学习工具&库: 1. Weka集成了数据挖掘工作的机器学习算法。这些算法可以直接应用于一个数据集上或者你可以自己编写代码来调用。Weka包括一系列的工具,如数据预处理...

oschina
2015/12/28
11.7K
11
JAVA接口的概念、分类及与抽象类的区别

Java接口(Interface),是一系列方法的声明,是一些方法特征的集合,一个接口只有方法的特征没有方法的实现,因此这些方法可以在不同的地方被不同的类实现,而这些实现可以具有不同的行为(...

郭二翔
2011/12/17
0
0
My java——JVM(内存域)三

续 My java——JVM(内存)二 写了一点JVM内存的一些操作的方法,和引出内存的分类。 是呀,java内存是我们在java编程中很少考虑到的,也没用真正的管理过。也许都知道JVM有自己的垃圾回收机...

tngou
2013/03/18
0
0

没有更多内容

加载失败,请刷新页面

加载更多

HBase新建表报错 org.apache.hadoop.hbase.TableExistsException

之前安装了旧版本的hbase, 没有清理其在Zookeeper上的内容。 解决办法 stop-hbase.sh zkCli.sh >>> rmr /hbase >>> quit start-hbase.sh...

dreamness
26分钟前
1
0
大数据技术的应用现状与展望

本文是我即将由嵌入式底层驱动行业转入大数据研究领域的综述文章,案例摘自《程序员》电子期刊,由于初学者知识面较窄,查看文献量较少,因此后续还会在此基础上,继续跟踪并深入研究,为论文...

陈小君
32分钟前
1
0
NCRE考试感想 三级信息安全(上)

时间节点 报名时间:2017-06 考试时间:2017-09 查询成绩:2017-11   考试简述 满分100分,时间120分钟。题型有三种,选择题、综合题、应用题。   备考经验 题库是WLJY的,买了激活码。为了...

志成就
39分钟前
1
0
百度地图显示我的位置

<!DOCTYPE html><html lang="en"><head><meta charset="UTF-8"><title></title><script type="text/javascript" src="jquery-1.8.2.min.js"></script></head><body><sec......

塔塔米
43分钟前
2
0
mysql mysql常用的常用函数

1. 数学函数 函 数 作 用 ABS(x) 返回x的绝对值 CEIL(x),CEILIN(x) 返回不小于x的最小整数值 FLOOR(x) 返回不大于x的最大整数值 RAND() 返回0~1的随机数 RAND(x) 返回0~1的随机数,x值相同返...

edison_kwok
今天
2
0

没有更多内容

加载失败,请刷新页面

加载更多

返回顶部
顶部