文档章节

【大数据分析常用算法】3.K-近邻算法

Areya
 Areya
发布于 2019/03/01 15:09
字数 2559
阅读 162
收藏 2

「深度学习福利」大神带你进阶工程师,立即查看>>>

简介

K-近邻(K-Nearest Neighbors, KNN)是一个非常简单的机器学习算法,很多机器学习算法书籍都喜欢将该算法作为入门的算法作为介绍。

KNN分类问题是找出一个数据集中与给定查询数据点最近的K个数据点。这个操作也成为KNN连接(KNN-join)。可以定义为:给定两个数据集R合S,对R中的每一个对象,我们希望从S中找出K个最近的相邻对象。

在数据挖掘中,R和S分别称为查询和训练(traning)数据集。训练数据集S表示已经分类的数据,而查询数据集R表示利用S中的分类来进行分类的数据。

KNN是一个比较重要的聚类算法,在数据挖掘(图像识别)、生物信息(如乳腺癌诊断)、天气数据生成模型和商品推荐系统中有很多应用。

缺点:开销大。特别是有一个庞大的训练集时。正是这个原因,使用MapReduce运行该算法显得非常的有用。

1、KNN算法

1.1、KNN分类

KNN的中心思想是建立一个分类方法,使得对于将y(响应变量)与x(预测变量)关联的平滑函数f的形势没有任何的假设: $$ x = (x_{1},x_{2},...,x_{n}) $$

$$ y = f(x) $$

函数f是非参数化的,因为它不涉及任何形式的参数估计。在KNN中,给定一个新的点$p=(p_{1},p_{2},...,p_{n})$,要动态的识别训练集数据集中与p相似的K个观察(k个近邻)。近邻由一个距离或相似度来定义。可以根据独立变量计算不同观察之间的距离,我们采用欧氏距离进行计算: $$ \sqrt{(x_{1} - p_{1})^2 + (x_{2} - p_{2})^2 + ... + (x_{n}-p_{n})^2} $$

关于距离的算法以及种类有很多,本章节我们采用欧氏距离,即坐标系距离计算方法。

那么如何找出k个近邻呢?

我们先计算出欧氏距离的集合,然后将这个查询对象分配到k个最近训练数据中大多数对象所在的类。

1.2、距离函数

假设有两个n维对象: $$ X = (X_{1},X_{2},...,X_{n}) $$

$$ Y = (Y_{1},Y_{2},...,Y_{n}) $$

$distance(X,Y)$可以定义如下: $$ distance(X,Y) = \sqrt{\sum_{i=1}^{n}(x_{i}-y_{i})^2} $$

注意欧氏距离只适用于连续性数值类型:double。如果是其他类型,则可以考虑关联业务情况下设置距离函数,将其转化为double类型。

关于所有的有关各种距离的介绍,参考博文:

1.3、KNN解析

KNN算法是一种对未分类数据进行分类的直观方法,他会根据未分类数据与训练数据集中的数据的相似度或距离完成分类。在下面的例子中,我们有4个分类$C_{1} - C_{4}$:

可以看到,我们的K=6,因此选取了6个近邻,在这6个近邻中,出现在上方的那个类中有4个属于它的点,因此,我们将P点归为上方圆圈包含的这一类型中。

1.4、算法描述

KNN算法可以总结为以下的步骤:

  1. 确定K
  2. 计算新输入与所有训练集之间的距离
  3. 对距离排序,并根据第k个最小距离确定k个近邻
  4. 手机这些近邻所属的类别
  5. 根据多数投票确定类别

算法复杂度:$O(N^2)$

2、Spark实现

2.1、形式化描述

设R和S是d维数据集,我们想找出其kNN(RS)。进一步假设所有训练数据(S)已经分类到$C={C_{1},C_{2},...,C_{n}}$,这里$C$表示所有可能的分类。R、S和C的定义如下: $$ R = {R_{1},R_{2},...,R_{n}} $$

$$ S = {S_{1},S_{2},...,S_{n}} $$

$$ C = {C_{1},C_{2},...,C_{n}} $$

在这里:

  1. $R_{i} = (r_{i},a_{1},a_{2},...,a_{n})$,其中$r_{i}$是当前记录的ID,$a_{1},...,a_{n}$是$R_{i}$的属性;
  2. $S_{j} = {r_{j},b_{1},b_{2},...,b_{n}}$同上。
  3. $C_{j}$是$S_{j}$的分类标识符。

我们的目标是找出$KNN(R,S)$。

2.2、数据集

S数据集如下所示:

100;c1;1.0,1.0
101;c1;1.1,1.2
102;c1;1.2,1.0
103;c1;1.6,1.5
104;c1;1.3,1.7
105;c1;2.0,2.1
106;c1;2.0,2.2
107;c1;2.3,2.3
208;c2;9.0,9.0
209;c2;9.1,9.2
210;c2;9.2,9.0
211;c2;10.6,10.5
212;c2;10.3,10.7
213;c2;9.6,9.1
214;c2;9.4,10.4
215;c2;10.3,10.3
300;c3;10.0,1.0
301;c3;10.1,1.2
302;c3;10.2,1.0
303;c3;10.6,1.5
304;c3;10.3,1.7
305;c3;1.0,2.1
306;c3;10.0,2.2
307;c3;10.3,2.3

其中,第一列为每条记录的唯一ID,第二列为该条记录的所属类别,之后的都为维度信息;

R数据集的信息如下:

1000;3.0,3.0
1001;10.1,3.2
1003;2.7,2.7
1004;5.0,5.0
1005;13.1,2.2
1006;12.7,12.7

其中,第一列为每条记录的唯一ID,之后的都为维度信息;

接下来我们使用KNN算法,来计算R数据集中每个记录所属的类别。

2.3、Spark实现


package com.sunrun.movieshow.autils.knn;

import com.google.common.base.Splitter;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.broadcast.Broadcast;
import scala.Tuple2;

import java.util.*;

public class KNNTester {
    /**
     * 1. 获取Spark 上下文对象
     * @return
     */
    public static JavaSparkContext getSparkContext(String appName){
        SparkConf sparkConf = new SparkConf()
                .setAppName(appName)
                //.setSparkHome(sparkHome)
                .setMaster("local[*]")
                // 串行化器
                .set("spark.serializer","org.apache.spark.serializer.KryoSerializer")
                .set("spark.testing.memory", "2147480000");

        return new JavaSparkContext(sparkConf);
    }

    /**
     * 2. 将数字字符串转换为Double数组
     * @param str 数字字符串: "1,2,3,4,5"
     * @param delimiter 数字之间的分隔符:","
     * @return Double数组
     */
    public static List<Double> transferToDoubleList(String str, String delimiter){
        // 使用Google Splitter切割字符串
        Splitter splitter = Splitter.on(delimiter).trimResults();
        Iterable<String> tokens = splitter.split(str);
        if(tokens == null){
            return null;
        }
        List<Double> list = new ArrayList<>();
        for (String token : tokens) {
            list.add(Double.parseDouble(token));
        }
        return list;
    }

    /**
     * 计算距离
     * @param rRecord R数据集的一条记录
     * @param sRecord S数据集的一条记录
     * @param d 记录的维度
     * @return 两条记录的欧氏距离
     */
    public static double calculateDistance(String rRecord, String sRecord, int d){
        double distance = 0D;
        List<Double> r = transferToDoubleList(rRecord,",");
        List<Double> s = transferToDoubleList(sRecord,",");
        // 若维度不一致,说明数据存在问题,返回NAN
        if(r.size() != d || s.size() != d){
            distance =  Double.NaN;
        } else{
            // 保证维度一致之后,计算欧氏距离
            double sum = 0D;
            for (int i = 0; i < s.size(); i++) {
                double diff = s.get(i) - r.get(i);
                sum += diff * diff;
            }
            distance = Math.sqrt(sum);
        }
        return distance;
    }

    /**
     * 根据(距离,类别),找出距离最低的K个近邻
     * @param neighbors 当前求出的近邻数量
     * @param k 寻找多少个近邻
     * @return K个近邻组成的SortedMap
     */
    public static SortedMap<Double, String>findNearestK(Iterable<Tuple2<Double,String>> neighbors, int k){
        TreeMap<Double, String> kNeighbors = new TreeMap<>();
        for (Tuple2<Double, String> neighbor : neighbors) {
            // 距离
            Double distance = neighbor._1;
            // 类别
            String classify = neighbor._2;
            kNeighbors.put(distance, classify);
            // 如果当前已经写入K个元素,那么删除掉距离最远的一个元素(位于末端)
            if(kNeighbors.size() > k){
                kNeighbors.remove(kNeighbors.lastKey());
            }
        }
        return kNeighbors;
    }

    /**
     * 计算对每个类别的投票次数
     * @param kNeighbors 选取的K个最近的点
     * @return 对每个类别的投票结果
     */
    public static Map<String, Integer> buildClassifyCount(Map<Double, String> kNeighbors){
        HashMap<String, Integer> majority = new HashMap<>();
        for (Map.Entry<Double, String> entry : kNeighbors.entrySet()) {
            String classify = entry.getValue();
            Integer count = majority.get(classify);
            // 当前没有出现过,设置为1,否则+1
            if(count == null){
                majority.put(classify,1);
            }else{
                majority.put(classify,count + 1);
            }
        }
        return  majority;
    }

    /**
     * 根据投票结果,选取最终的类别
     * @param majority 投票结果
     * @return 最终的类别
     */
    public static String classifyByMajority(Map<String, Integer> majority){
        String selectedClassify = null;
        int maxVotes = 0;
        // 从投票结果中选取票数最多的一类作为最终选举结果
        for (Map.Entry<String, Integer> entry : majority.entrySet()) {
            if(selectedClassify == null){
                selectedClassify = entry.getKey();
                maxVotes = entry.getValue();
            }else{
                int nowVotes = entry.getValue();
                if(nowVotes > maxVotes){
                    selectedClassify = entry.getKey();
                    maxVotes = nowVotes;
                }
            }
        }
        return selectedClassify;
    }



    public static void main(String[] args) {
        // === 1.创建SparkContext
        JavaSparkContext sc = getSparkContext("KNN");

        // === 2.KNN算法相关参数:广播共享对象
        String HDFSUrl = "hdfs://10.21.1.24:9000/output/";
        // k(K)
        Broadcast<Integer> broadcastK = sc.broadcast(6);
        // d(维度)
        Broadcast<Integer> broadcastD = sc.broadcast(2);

        // === 3.为查询和训练数据集创建RDD
        // R and S
        String RPath = "data/knn/R.txt";
        String SPath = "data/knn/S.txt";
        JavaRDD<String> R = sc.textFile(RPath);
        JavaRDD<String> S = sc.textFile(SPath);
//        // === 将R和S的数据存储到hdfs
//        R.saveAsTextFile(HDFSUrl + "S");
//        S.saveAsTextFile(HDFSUrl + "R");

        // === 5.计算R&S的笛卡尔积
        JavaPairRDD<String, String> cart = R.cartesian(S);
        /**
         * (1000;3.0,3.0,100;c1;1.0,1.0)
         * (1000;3.0,3.0,101;c1;1.1,1.2)
         */

        // === 6.计算R中每个点与S各个点之间的距离:(rid,(distance,classify))
        // (1000;3.0,3.0,100;c1;1.0,1.0) => 1000 is rId, 100 is sId, c1 is classify.
        JavaPairRDD<String, Tuple2<Double, String>> knnPair = cart.mapToPair(t -> {
            String rRecord = t._1;
            String sRecord = t._2;

            // 1000;3.0,3.0
            String[] splitR = rRecord.split(";");
            String rId = splitR[0]; // 1000
            String r = splitR[1];// "3.0,3.0"

            // 100;c1;1.0,1.0
            String[] splitS = sRecord.split(";");
            // sId对于当前算法没有多大意义,我们只需要获取类别细信息,即第二个字段的信息即可
            String sId = splitS[0]; // 100
            String classify = splitS[1]; // c1
            String s = splitS[2];// "3.0,3.0"

            // 获取广播变量中的维度信息
            Integer d = broadcastD.value();
            // 计算当前两个点的距离
            double distance = calculateDistance(r, s, d);
            Tuple2<Double, String> V = new Tuple2<>(distance, classify);
            // (Rid,(distance,classify))
            return new Tuple2<>(rId, V);
        });
        /**
         * (1005,(2.801785145224379,c3))
         * (1006,(4.75078940808788,c2))
         * (1006,(4.0224370722237515,c2))
         * (1006,(3.3941125496954263,c2))
         * (1006,(12.0074976577137,c3))
         * (1006,(11.79025020938911,c3)
         */


        // === 7. 按R中的r根据每个记录进行分组
        JavaPairRDD<String, Iterable<Tuple2<Double, String>>> knnGrouped = knnPair.groupByKey();
        // (1005,[(12.159358535712318,c1),....,(7.3171032519706865,c3), (7.610519036176179,c3)]),
        // (1000,[(2.8284271247461903,c1), (2.6172504656604803,c1), (2.690724....])

        // === 8.找出每个R节点的k个近邻
        JavaPairRDD<String, String> knnOutput = knnGrouped.mapValues(t -> {
            // K
            Integer k = broadcastK.value();
            SortedMap<Double, String> nearestK = findNearestK(t, k);
            // {2.596150997149434=c3, 2.801785145224379=c3, 2.8442925306655775=c3, 3.0999999999999996=c3, 3.1384709652950433=c3, 3.1622776601683795=c3}

            // 统计每个类别的投票次数
            Map<String, Integer> majority = buildClassifyCount(nearestK);
            // {c3=1, c1=5}

            // 按多数优先原则选择最终分类
            String selectedClassify = classifyByMajority(majority);
            return selectedClassify;
        });

        // 存储最终结果
        knnOutput.saveAsTextFile(HDFSUrl + "/result");
        /**
         * [root@h24 hadoop]# hadoop fs -cat /output/result/p*
         * (1005,c3)
         * (1001,c3)
         * (1006,c2)
         * (1003,c1)
         * (1000,c1)
         * (1004,c1)
         */
    }
}

步骤7和8也可以通过reduceByKey或者CombineByKey进行一步到位。先来看看我们的转换过程:

RDD:
—— knnPair: JavaPairRDD<String, Tuple2<Double, String>>
—— knnGrouped: JavaPairRDD<String, Iterable<Tuple2<Double, String>>>
—— knnOutput:JavaPairRDD<String, String>

变换过程:

knnPair    => groupBy   => knnGrouped
knnGrouped => mapValues => knnOutput

显然,我们无法使用reduceByKey,因此他要求输出类型等同于输入类型。聚集的返回类型不同于聚集值的类型时就要使用combineByKey变换。因此,我们将使用combineByKey把步骤7和8合并到一起。这个合并步骤如下:

RDD:

—— knnPair: JavaPairRDD<String, Tuple2<Double, String>>
—— knnOutput: JavaPairRDD<String, String>

变换过程:

—— knnPair => combineByKey => knnOutput
Areya
粉丝 28
博文 101
码字总数 255789
作品 0
广州
私信 提问
加载中
请先登录后再评论。
Flappy Bird(安卓版)逆向分析(一)

更改每过一关的增长分数 反编译的步骤就不介绍了,我们直接来看反编译得到的文件夹 方法1:在smali目录下,我们看到org/andengine/,可以知晓游戏是由andengine引擎开发的。打开/res/raw/at...

enimey
2014/03/04
6.1K
18
实时分析系统--istatd

istatd是IMVU公司工程师开发的一款优秀的实时分析系统,能够有效地收集,存储和搜索各种分析指标,类似cacti,Graphite,Zabbix等系统。实际上,istatd修改了Graphite的存储后端,重新实现了...

匿名
2013/02/07
3.1K
1
日志分析平台 - Kibana

Kibana 是一个为 Logstash 和 ElasticSearch 提供的日志分析的 Web 接口。可使用它对日志进行高效的搜索、可视化、分析等各种操作。 环境要求: ruby >= 1.8.7 (probably?) bundler logstash...

匿名
2013/02/13
11.6W
1
Swing界面分析和调试工具--Swing Inspector

Swing Inspector是一个Java Swing/AWT用户界面分析和调试工具,功能与firebug类似,具有强大的Swing/AWT用户界面分析和调试相关功能。 适用于从java swing初级到高级的所有开发人员,能够快速...

匿名
2013/03/06
3.4K
0
硬实时操作系统--Raw OS

Raw-OS 起飞于2012年,Raw-OS志在制作中国人自己的最优秀硬实时操作系统。 Raw-OS 操作系统特性 内核最大关中断时间无限接近0us, s3c2440系统最大关中断时间实测0.8us。 支持idle任务级别的事...

jorya_txj
2013/03/19
6.3K
1

没有更多内容

加载失败,请刷新页面

加载更多

利用Numpy中的ascontiguousarray可以是数组在内存上连续,加速计算

1. 概述 在使用Numpy的时候,有时候会遇到下面的错误: AttributeError: incompatible shape for a non-contiguous array 看报错的字面意思,好像是不连续数组的shape不兼容。 有的时候,在看...

osc_9we1w99u
9分钟前
0
0
如何管理客户的期望值?

根据客户关系管理(CRM)中的三角定律,客户满意度=客户体验-客户期望值。客户期望值与客户满意度成相对反比,因此需要引导客户期望值并维持在一个适当的水平,同时客户期望值需要与客户体验协...

cailisuper
今天
0
0
阿里研究员:软件测试中的18个难题

阿里QA导读:对于软件测试来说,怎么样才算测够了?如何评价测试的有效性?那么多测试用例,以后怎么删?在软件测试中会遇到非常多的问题,阿里研究员郑子颖分享了18个他总结出的难题以及相关...

阿里巴巴技术质量
昨天
0
0
Numpy的常用函数总结

1、np.argmax()、np.max()、np.argmin()、np.min()用法: argmax返回的是最大数的索引.argmax有一个参数axis,默认是0。看二维的情况如下: a = np.array([[1, 5, 5, 2],            ...

osc_auwur47t
11分钟前
0
0
【报告分享】2020抖音进阶-挑战赛2.0产品营销方案.pdf(附下载链接)

大家好,我是文文(微信:sscbg2020),今天给大家分享抖音营销中心出品的《2020抖音进阶-挑战赛2.0产品营销方案.pdf》,方案里面的玩法解析、案例、营销重点分析等都很清晰,对短视频及品牌...

智能推荐系统
昨天
11
0

没有更多内容

加载失败,请刷新页面

加载更多

返回顶部
顶部