文档章节

Spark上的决策树(Decision Tree On Spark)

hblt-j
 hblt-j
发布于 2017/09/05 12:00
字数 3174
阅读 16
收藏 0

实现概要

在陷入实现细节之前,我们先从全局大方面上来把握一下MLlib是如何实现分布式决策树的。

  • 首先,MLlib认为,决策树是随机森林(RandomForest)的一种特殊情况,也就是只有一棵树并且不采取特征抽样的随机森林。所以在训练决策树的时候,其实是训练随机森林,最后从随机森林中抽出一棵树。
  • 为了减少分布式训练过程中遍历数据的次数和提高训练速度,实现上采取了以下几个优化技巧:
    • 以广度优先方式建树(传统的实现是递归版本的深度优先方式)
    • 广度优先获得在maxMemory限制下的队列中的节点,作为一组,按组训练,这样每一次遍历数据需要做更多的计算和更多的存储空间,但是相应地减少了网络通信
    • 提前计算特征的切割点(Split)和切割区间(Bin),在数据量大的情况下,可以近似按Bin寻找最优切割点,而不用遍历训练数据的所有可能分割点
    • 利用已知的Bin和每个Bin需要的统计量个数构造一个一维数组,进行分区统计再合并

优化技巧剖析

训练随机森林获取决策树

其实这算不上什么优化技巧,为了逻辑上连贯,还是加上了:-),什么都不用说,直接上代码,请看DecisionTree中的run方法便知:

  def run(input: RDD[LabeledPoint]): DecisionTreeModel = {
    // Note: random seed will not be used since numTrees = 1.
    val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0)
    val rfModel = rf.run(input)
    rfModel.trees(0)
  }

按层建树

其实这个技巧挺简单的,只要你知道如何按层打印二叉树的节点就可以,这种可是经典的面试题来的-:),很简单,只需利用一个队列来辅助即可。每次将队列的节点全部拿出来,按顺序处理每个节点,并将产生的新节点重新进队列,直到队列为空。

queue.enqueue(root)
while !queue.isEmpty:
    nodes = extractNodes(queue) //取出队列中的节点
    for node in nodes:
        growTree(node,queue)  //此处子节点可能会进入queue

这样便实现了按层建树的过程。MLlib中将这样每一次迭代从queue中获取的节点归为一组,并考虑每一组是需要用到的内存是否满足最大的内存限制,所以并不是每一次迭代都取整层的节点。也就是说每一组可能有不同层次的节点,因为是训练随机森林,所以每一组的节点可能来源不止一棵树。通过分组的操作,每一次遍历数据,可以操作当前组的所有节点,而不是只处理一个节点,从而减少了数据的遍历次数。

提前计算特征的切割点(Split)和切割区间(Bin)

遍历一遍输入数据或采样数据,我们就可以提前知道所有可能的分裂点。决策树的一个优势是可以处理连续特征(Continuous feature)和类别特征(Category feature)。

对于类别特征,比如一个颜色特征,它的特征值可以是:红,黄,蓝,绿。假设一个类别特征的特征值数目为N,因为一个特征可以同时取得多个特征值,比如红蓝,蓝绿,那么分割点其实就是特征值的所有可能组合,其个数为:2N,对于二叉决策树而言,有一半的分裂点其实是重复的,比如选红蓝为分割点和选黄绿为分割点其实是一样的,所以必须除以2,也就是2N−1,对于其中一种情况是取得所有特征值或者一种都取不到的情况,必须排除,所以最终的分裂点Split个数就是:2N−1−1,而区间Bin个数就是2N−2。具体实现中,MLlib采用了一个可证明的技巧(详请查阅《The Elements Of Statistical Learning》9.2.4节),对于二元分类问题,分裂点Split个数直接设为N−1,Bin的个数为N

对于连续的实数特征,如果数据集很小的情况下,通常会将特征的取值进行排序,然后遍历每个取值,尝试分裂,选出满足信息增益最大或者Gini-index最纯的分裂点。在数据量很大的情况下,对每个可能取值进行排序成本就太大了,一个惯用的近似技巧——“箱化” (Binning,我们觉得取值太多了,一个个处理太麻烦,打包起来,处理就简单多了)。假如原来特征的取值是:

1,2,3,4,5,6,7,8,9,....

经过打包后可能就变成如下:

(1,2),(3,4),(5,6),(7,8),....

然后以箱子为单位,根据每个箱子的最小值和最大值,可以确定划分边界,然后按照信息增益或者其他衡量方式确定最终分裂边界。那么采用这种技巧之后,Bin个数最多就是所有取值的情况的个数,也就是N(相当于没有使用”箱化”技巧),而Split的个数就是N−1。(感谢网友提出此处的错误,原来是错误认为Bin个数为N+1, Split个数是N)。

 

但是对于海量数据,或者一个无序的特征有太多的特征值,就有必要控制Bin的个数了,所以一个近似的做法就是,提前确定好箱子Bin的个数,也就是为什么在决策树设定参数中有maxBins这个参数的来由了。对于类别特征如果特征值超过maxBins,那么将分裂箱子Bin的数量退化为特征值的个数。对于连续的特征,如果不同训练特征少于maxBins,那么还是按照前面分析的做法,如果超过了,Bin的个数就设为maxBins,并采取尽量平均的方式选择切割点,使得每个Bin尽量包含相同个数的训练数据。如果训练数据实在太多,可以使用采样的方式,利用采样部分数据作为训练数据再使用上面的方法确定Split和Bin。由于采取了分区间的操作和可能的采样手段,必然降低了决策树的预测精度,但是另一方面却可以大大提升训练速度。实际中据说这样的技巧也没损伤多少精度-:)。

以上分析位于代码DecisionTreeMetadata.buildMetadata方法和DecisionTree.findSplitsBins方法中。其中DecisionTreeMetadata.buildMetadata设定了无序特征的分裂数目,而DecisionTree.findSplitsBins则确定了连续特征的分裂数并且实际生成连续特征、有序类别特征、无序类别特征的分裂对象(Split)和分裂区间(Bin)

数据切割

正常的决策树算法在某个节点选到最佳split之后,会将数据以此切割成两部分,然后继续递归去寻找分裂节点,对于mllib的决策树实现来说,这种递归实现做法是不可行的,那么如何将数据进行切割呢?其实也很简单,假设此时决策树已经构建了一部分,然后对于一个新的节点,我们想在此节点上通过信息增益或者其他函数选择最优的split,那么如何计算落在该节点的数据呢?一个简单的方式就是模拟决策树的预测过程,一个数据从根节点开始走预测流程,如果最终落到我们目标节点上,那么该数据就是参与接下来计算最佳分裂算法的所需数据之一,mllib上在binSeqOp函数中就是通过predictNodeIndex函数实现对数据的切割的。

按节点分区计算部分统计量

因为Spark的RDD数据是以Partition分区存储的。所以如果能先利用分区计算部分统计量,最后再合并统计量,就可以减少很多不必要的通信开销。那么该怎样分区统计并且使得后面合并的时候方便呢?MLlib的具体实现是为每个节点创建一个一维数组allStats作为统计的容器,怎样一个一维数组呢?由于上一步的计算,我们已经提前知道每个特征对应的Bin的个数了,那么每个Bin里面到底需要多少统计量呢?对于分类问题,假设是二元分类,那么每个Bin其实只有2个统计量,就是计算落到这个Bin里面正负样本的个数。而多类分类问题,分类个数N,则每个Bin里面就需要N个统计量。给个图直观展示一下,对于一个3类分类决策树,构造这样的一维数组allStats形式如下:
这里写图片描述
每个Bin都有3个类别的count,并顺序排列下去组成一个大的一维数组。这样的大数组涵盖了我们计算的所有可能性,为每个节点创建这样一个数组,都会消耗一定内存,所以设置maxBins需要小心。

既然我们提前知道每个特征对应的Bin的个数和每个Bin需要的统计量个数,我们可以设置一个数组featureOffsets,大小是featureNum,从0开始,累加每个feature对应的Bin数目,也就是进行Cumulative sum 操作。这样数组最后一个元素值就是总的Bin的个数totalBins。计算featureOffsets的代码如下:

  private val featureOffsets: Array[Int] = {
    numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
  }

设置这样的偏移数组的好处就是一旦获得featureIndex,我们很方便查询到具体在上面大数组中的偏移量。给定一个featureIndex,还有binIndex,假设每个Bin的统计量为statSize,那么在大数组的更新偏移量为:

offset=featureOffsets(featureIndex)+binIndexstatSize

 

整理一下,我们之前已经事先计算每一个特征可能落入的Bin和切割点Split。也就是说一个特征取得不同的值将可能会落入不同的Bin中,但是对应具体的一个训练数据LabeledPoint,在每个特征上的取值已经是确定了,那么对于该LabeledPoint,我们可以事先计算每个特征对应落到那个Bin中,由于我们之前计算连续特征对应的Split和Bin的时候是有序的,那么可以利用二分查找寻找每个具体特征值的对应Bin,这也大大地节约了计算量。对于Categorical特征,则特征值就作为Bin的Index。具体可参考TreePoint的实现。LabeledPoint到TreePoint的转换其实就是将LabeledPoint里面的每一个特征值映射到每个Bin的Index,对于每一个TreePoint,我们因此可以知道它所有落入不同Bin的位置并更新那个Bin的统计量。
这样复杂的计算过程,在MLlib中实现抽象为类DTStatsAggregator,每个节点都有对应的DTStatsAggregator,DTStatsAggregator中包含前面介绍的allStats和featureOffsets。用于计算不同Bin在各个RDD分区的部分统计,最后再由reduceByKey合并起来,变成一个充分统计。示意如下:
这里写图片描述

因为这个技巧需要综合之前所有技巧,并且为了效率,实现上没有过多抽象,读起源码来难度会比较大。所以读一两遍读不懂不要气馁-:)。在真正分裂节点的时候,Continuous Feature和Categorical Feature是在计算信息增益的形式是有所不同的,并且都运用了Cumulative sum的技巧,但是这已经不是为了实现分布式决策树的技巧,这里就不再赘述。

结语

其实单单靠文字很难表达清楚整个实现过程,但是本文也差不多点出了MLlib中DecisionTree的核心要点,我并不希冀读者通过阅读本文就可以完全理解,但是可以根据本文点出的概念,再阅读源码,读5遍左右(我就是读了5遍),应该可以完全理解了-:)。决策树是随机森林和梯度提升树的基础,理解了决策树,再看其他两种模型,都是可以秒懂的-:)。

参考引用

《Scalable Distributed Decision Trees in Spark MLlib》
《PLANET: Massively Parallel Learning of Tree Ensembles with MapReduce》
官方文档Decision Trees - spark.mllib

本文转载自:http://blog.csdn.net/aws3217150/article/details/51909792

共有 人打赏支持
hblt-j
粉丝 16
博文 116
码字总数 56931
作品 0
海淀
架构师
利用KNIME建立Spark Machine learning模型 2:泰坦尼克幸存预测

本文利用KNIME基于Spark决策树模型算法,通过对泰坦尼克的包含乘客及船员的特征属性的训练数据集进行训练,得出决策树幸存模型,并利用测试数据集对模型进行测试。 1、从Kaggle网站下载训练...

forestwater
05/09
0
0
Spark之获取GBT二分类函数的概率值

  在Spark中,GBT(Gradient Boost Trees,提升树)函数用于实现机器学习中的提升树算法,目前仅支持二分类算法。笔者在实际工作中需要获得其预测的概率值,无奈该函数没有相应的方法。  ...

jclian91
2017/10/09
0
0
Spark 学习资源收集【Updating】

(一)spark 相关安装部署、开发环境 1、Spark 伪分布式 & 全分布式 安装指南 http://my.oschina.net/leejun2005/blog/394928 2、Apache Spark探秘:三种分布式部署方式比较 http://dongxic...

大数据之路
2014/09/08
0
1
Spark的39个机器学习库-英文

Apache Spark itself 1. MLlib AMPLab Spark originally came out of Berkeley AMPLab and even today AMPLab projects, even though they are not in Apache Spark Foundation, enjoy a sta......

MoksMo
2015/11/04
0
1
Spark的39个机器学习库-中文

//Apache Spark 本身// 1.MLlib >AMPLab Spark最初诞生于伯克利 AMPLab实验室,如今依然还是AMPLab所致力的项目,尽管这些不处于Apache Spark Foundation中,但是依然在你日常的github项目中...

MoksMo
2015/11/04
0
0

没有更多内容

加载失败,请刷新页面

加载更多

下一页

(一)软件测试专题——之Linux常用命令篇01

本文永久更新地址:https://my.oschina.net/bysu/blog/1931063 【若要到岸,请摇船:开源中国 不最醉不龟归】 Linux的历史之类的很多书籍都习惯把它的今生来世,祖宗十八代都扒出来,美其名曰...

不最醉不龟归
14分钟前
3
0
蚂蚁金服Java开发三面

8月20号晚上8点进行了蚂蚁金服Java开发岗的第三面,下面开始: 自我介绍(要求从实践过程以及技术背景角度着重介绍) 实习经历,说说你在公司实习所做的事情,学到了什么 关于你们的交易平台...

edwardGe
21分钟前
7
0
TypeScript基础入门 - 函数 - this(三)

转载 TypeScript基础入门 - 函数 - this(三) 项目实践仓库 https://github.com/durban89/typescript_demo.gittag: 1.2.4 为了保证后面的学习演示需要安装下ts-node,这样后面的每个操作都能...

durban
30分钟前
0
0
Spark core基础

Spark RDD的五大特性 RDD是由一系列的Partition组成的,如果Spark计算的数据是在HDFS上那么partition个数是与block数一致(大多数情况) RDD是有一系列的依赖关系,有利于Spark计算的容错 RDD中每...

张泽立
38分钟前
0
0
如何搭建Keepalived+Nginx+Tomcat高可用负载均衡架构

一.概述 初期的互联网企业由于业务量较小,所以一般单机部署,实现单点访问即可满足业务的需求,这也是最简单的部署方式,但是随着业务的不断扩大,系统的访问量逐渐的上升,单机部署的模式已...

Java大蜗牛
53分钟前
0
0

没有更多内容

加载失败,请刷新页面

加载更多

下一页

返回顶部
顶部