
Spark GBT Algorithm

tuoleisi77
Published 2017/07/12 14:46

Gradient-Boosted Trees (GBTs) are a popular classification and regression method built on ensembles of decision trees.

In practice, GBTs often outperform Random Forests: the trees are trained sequentially, with each new tree fitted to correct the errors of the ensemble built so far.

Straight to the code:

package mllib

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.GBTClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature._
import org.apache.spark.sql.SparkSession

/**
  * Created by dongdong on 17/7/10.
  */

case class Feature_One(
                         cid: String,
                         population_gender: String,
                         population_age: Double,
                         population_registered_gps_city: String,
                         population_education_nature: String,
                         population_university_level: String,
                         sociality_channel_type: String,
                         action_registered_channel: String,
                         action_this_month_once_week_average_login_count: Double,
                         population_censu_city: String,
                         population_gps_city: String,
                         population_own_cell_city: String,
                         population_rank1_cell_city: String,
                         population_rank1_cell_cnt: Double,
                         population_rank2_cell_city: String,
                         population_rank2_cell_cnt: Double,
                         population_rank3_cell_city: String,
                         population_rank3_cell_cnt: Double,
                         population_gps_censu_flag: Double,
                         population_own_censu_flag: Double,
                         population_gps_own_flag: Double,
                         population_own_txl_flag: Double,
                         population_gps_txl_flag: Double,
                         population_censu_txl_flag: Double,
                         population_cnt_7day_province: Double,
                         population_cnt_7day_city: Double,
                         population_cnt_login: Double,
                         population_before_apply_city: String,
                         population_after_apply_city: String,
                         population_before_in_apply_address: Double,
                         population_before_after_apply_address: Double,
                         population_in_after_apply_address: Double,
                         population_re_address_steady: String,
                         population_apply_address_steady: String,
                         population_score_fake_gps: Double,
                         population_score_fake_contacts: Double,
                         text: String,
                         flag: String
                       )

object GBT_Profile {

  def main(args: Array[String]): Unit = {
    
    val inpath1 = "/Users/ant_git/src/data/user_profile_train/part-00000"

    val spark = SparkSession
      .builder()
      .master("local[3]")
      .appName("GBT_Profile")
      .getOrCreate()
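    // note: local[3] runs Spark locally with 3 worker threads; a real
    // deployment would use a cluster master URL instead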
    import spark.implicits._

    // read the raw text file and convert it to a Dataset of Feature_One
    val originalData = spark.sparkContext
      .textFile(inpath1)
      .map(line => {
        val arr = line.split("\001")
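        // "\001" (Ctrl-A) is Hive's default field delimiter and "\N" its NULL
        // marker, so this input was most likely exported from a Hive table; the
        // replace calls below substitute "N" / "0" defaults for NULLs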
        val cid = arr(0)
        val population_gender = arr(3).replace("\\N", "N")
        val population_age = arr(4).replace("\\N", "0").toDouble
        val population_registered_gps_city = arr(7).replace("\\N", "N")
        val population_education_nature = arr(10).replace("\\N", "N")
        val population_university_level = arr(11).replace("\\N", "N")
        val sociality_channel_type = arr(13).replace("\\N", "N")
        val action_registered_channel = arr(44).replace("\\N", "N")
        val action_this_month_once_week_average_login_count = arr(54).replace("\\N", "0").toDouble
        val population_censu_city = arr(63).replace("\\N", "N")
        val population_gps_city = arr(64).replace("\\N", "N")
        // val population_jz_city = arr(65).replace("\\N", "N")
        // val population_ip_city = arr(66).replace("\\N", "N")
        val population_own_cell_city = arr(67).replace("\\N", "N")
        val population_rank1_cell_city = arr(68).replace("\\N", "N")
        val population_rank1_cell_cnt = arr(69).replace("\\N", "0").toDouble
        val population_rank2_cell_city = arr(70).replace("\\N", "N")
        val population_rank2_cell_cnt = arr(71).replace("\\N", "0").toDouble
        val population_rank3_cell_city = arr(72).replace("\\N", "N")
        val population_rank3_cell_cnt = arr(73).replace("\\N", "0").toDouble
        //val population_jxl_call_max_city = arr(74).replace("\\N", "N")
        // val population_jxl_call_max_city_cnt = arr(75).replace("\\N", "0").toDouble
        //val population_anzhuo_30day_max_city = arr(76).replace("\\N", "N")
        //val population_anzhuo_30day_max_city_cnt = arr(77).replace("\\N", "0").toDouble
        val population_gps_censu_flag = arr(78).replace("\\N", "0").toDouble
        //val population_gps_jxl_flag = arr(79).replace("\\N", "0").toDouble
        //val population_gps_jz_flag = arr(80).replace("\\N", "0").toDouble
        //val population_ip_censu_flag = arr(81).replace("\\N", "0").toDouble
        // val population_ip_jxl_flag = arr(82).replace("\\N", "0").toDouble
        //val population_ip_jz_flag = arr(83).replace("\\N", "0").toDouble
        val population_own_censu_flag = arr(84).replace("\\N", "0").toDouble
        //val population_own_jxl_flag = arr(85).replace("\\N", "0").toDouble
        //val population_own_jz_flag = arr(86).replace("\\N", "0").toDouble
        val population_gps_own_flag = arr(87).replace("\\N", "0").toDouble
        //val population_gps_ip_flag = arr(88).replace("\\N", "0").toDouble
        //val population_ip_own_flag = arr(89).replace("\\N", "0").toDouble
        //val population_ip_txl_flag = arr(90).replace("\\N", "0").toDouble
        val population_own_txl_flag = arr(91).replace("\\N", "0").toDouble
        val population_gps_txl_flag = arr(92).replace("\\N", "0").toDouble
        val population_censu_txl_flag = arr(93).replace("\\N", "0").toDouble
        //val population_jxl_txl_flag = arr(94).replace("\\N", "0").toDouble
        //val population_jz_txl_flag = arr(95).replace("\\N", "0").toDouble
        val population_cnt_7day_province = arr(96).replace("\\N", "0").toDouble
        val population_cnt_7day_city = arr(97).replace("\\N", "0").toDouble
        val population_cnt_login = arr(102).replace("\\N", "0").toDouble
        val population_before_apply_city = arr(107).replace("\\N", "N")
        val population_after_apply_city = arr(108).replace("\\N", "N")
        val population_before_in_apply_address = arr(111).replace("\\N", "0").toDouble
        val population_before_after_apply_address = arr(112).replace("\\N", "0").toDouble
        val population_in_after_apply_address = arr(113).replace("\\N", "0").toDouble
        val population_re_address_steady = arr(116).replace("\\N", "N")
        val population_apply_address_steady = arr(117).replace("\\N", "N")
        val population_score_fake_gps = arr(127).replace("\\N", "0").toDouble
        val population_score_fake_contacts = arr(128).replace("\\N", "0").toDouble
        val text = population_gender + "|" +
          population_registered_gps_city + "|" +
          population_education_nature + "|" +
          population_university_level + "|" +
          sociality_channel_type + "|" +
          action_registered_channel + "|" +
          population_censu_city + "|" +
          population_gps_city + "|" +
          population_own_cell_city + "|" +
          population_rank1_cell_city + "|" +
          population_rank2_cell_city + "|" +
          population_rank3_cell_city + "|" +
          population_before_apply_city + "|" +
          population_after_apply_city + "|" +
          population_re_address_steady + "|" +
          population_apply_address_steady
        val flag = arr(141)
        Feature_One(
          cid,
          population_gender,
          population_age,
          population_registered_gps_city,
          population_education_nature,
          population_university_level,
          sociality_channel_type,
          action_registered_channel,
          action_this_month_once_week_average_login_count,
          population_censu_city,
          population_gps_city,
          population_own_cell_city,
          population_rank1_cell_city,
          population_rank1_cell_cnt,
          population_rank2_cell_city,
          population_rank2_cell_cnt,
          population_rank3_cell_city,
          population_rank3_cell_cnt,
          population_gps_censu_flag,
          population_own_censu_flag,
          population_gps_own_flag,
          population_own_txl_flag,
          population_gps_txl_flag,
          population_censu_txl_flag,
          population_cnt_7day_province,
          population_cnt_7day_city,
          population_cnt_login,
          population_before_apply_city,
          population_after_apply_city,
          population_before_in_apply_address,
          population_before_after_apply_address,
          population_in_after_apply_address,
          population_re_address_steady,
          population_apply_address_steady,
          population_score_fake_gps,
          population_score_fake_contacts,
          text,
          flag
        )
      }
      ).toDS


    // index the string label column
    val labelIndexer = new StringIndexer()
      .setInputCol("flag")
      .setOutputCol("indexedLabel")
      .fit(originalData)
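    // note: StringIndexer orders labels by frequency by default, so the most
    // frequent flag value is mapped to index 0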

    // split the pipe-delimited text field into tokens
    val tokenizer = new RegexTokenizer()
      .setInputCol("text")
      .setOutputCol("words")
      .setPattern("\\|")

    // embed the tokens with Word2Vec
    val word2Vec = new Word2Vec()
      .setInputCol("words")
      .setOutputCol("word2feature")
      .setVectorSize(100)
      //.setMinCount(1)
      .setMaxIter(10)
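    // note: Word2VecModel averages the vectors of all tokens in a row, so each
    // row's categorical text becomes a single 100-dim dense vector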

    // columns to assemble into the model's feature vector
    val arr = Array("population_age",
      "action_this_month_once_week_average_login_count",
      "population_rank1_cell_cnt",
      "population_rank2_cell_cnt",
      "population_rank3_cell_cnt",
      "population_gps_censu_flag",
      "population_own_censu_flag",
      "population_gps_own_flag",
      "population_own_txl_flag",
      "population_gps_txl_flag",
      "population_censu_txl_flag",
      "population_cnt_7day_province",
      "population_cnt_7day_city",
      "population_cnt_login",
      "population_before_in_apply_address",
      "population_before_after_apply_address",
      "population_in_after_apply_address",
      "population_score_fake_gps",
      "population_score_fake_contacts",
      "word2feature"
    )
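    // 19 numeric columns plus the 100-dim word2feature vector = 119 inputs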
    // merge the columns above into a single vector column
    val vectorAssembler = new VectorAssembler()
      .setInputCols(arr)
      .setOutputCol("assemblerVector")

    // create the GBT classifier
    val gbt = new GBTClassifier()
      .setLabelCol("indexedLabel")
      .setFeaturesCol("assemblerVector")
      // number of boosting iterations
      .setMaxIter(25)
      // maximum tree depth
      .setMaxDepth(5)
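    // note: Spark ML's GBTClassifier supports binary classification only,
    // which matches the 0/1 flag used here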

    // map the numeric predictions back to the original flag strings
    val labelConverter = new IndexToString()
      .setInputCol("prediction")
      .setOutputCol("predictedLabel")
      .setLabels(labelIndexer.labels)

    // 80/20 train/test split; passing a second seed argument to randomSplit
    // would make the split reproducible
    val Array(trainingData, testData) = originalData.randomSplit(Array(0.8, 0.2))

    val pipeline = new Pipeline().setStages(Array(labelIndexer, tokenizer, word2Vec, vectorAssembler, gbt, labelConverter))

    // fit on the training split only, so that testData remains truly held out
    val model = pipeline.fit(trainingData)

    val predictionResultDF = model.transform(testData)

    predictionResultDF.show(false)

    // count actual positives and the correct predictions for each class
    val label_1 = predictionResultDF
      .filter($"flag" === 1)
      .count()

    val correct_1 = predictionResultDF
      .filter($"flag" === $"predictedLabel")
      .filter($"predictedLabel" === 1)
      .count()

    val correct_0 = predictionResultDF
      .filter($"flag" === $"predictedLabel")
      .filter($"predictedLabel" === 0)
      .count()

    println(s"actual 1s: $label_1, correct 1s: $correct_1, correct 0s: $correct_0")

    // write the users predicted as label 1 to a single CSV file
    predictionResultDF.select("cid", "predictedLabel")
      .filter($"predictedLabel" === 1)
      .repartition(1)
      .write.format("csv")
      .save("/Users/ant_git/Antifraud/src/data/predict/")
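
    // a small addition not in the original post: precision and recall for the
    // positive class, derived from the counts above (predicted_1_cnt is a new
    // helper count introduced here)
    val predicted_1_cnt = predictionResultDF.filter($"predictedLabel" === 1).count()
    if (predicted_1_cnt > 0 && label_1 > 0) {
      println(s"label-1 precision = ${correct_1.toDouble / predicted_1_cnt}")
      println(s"label-1 recall    = ${correct_1.toDouble / label_1}")
    }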

    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("indexedLabel")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")
    val accuracy = evaluator.evaluate(predictionResultDF)
    val error = 1.0 - accuracy
    println("Test Error = " + error)

    spark.stop()
  }
}

Summary: the algorithm itself comes already packaged; the most important part is how the features are processed. With good features, even a very simple algorithm can classify well; with poor features, even the best model will struggle to produce good results. That is why feature selection is so important in machine learning.
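Following that point, here is a minimal sketch (not in the original post) of how one could check which inputs the trained GBT actually leans on, assuming Spark 2.0+ where GBTClassificationModel exposes featureImportances; the stage index and the top-10 printout are illustrative choices:

import org.apache.spark.ml.classification.GBTClassificationModel

// the fitted GBT sits at stage index 4 of the PipelineModel built above:
// (labelIndexer, tokenizer, word2Vec, vectorAssembler, gbt, labelConverter)
val gbtModel = model.stages(4).asInstanceOf[GBTClassificationModel]

// one importance weight per assembled column: slots 0-18 are the numeric
// fields listed in `arr`, slots 19-118 the Word2Vec dimensions
gbtModel.featureImportances.toArray.zipWithIndex
  .sortBy { case (weight, _) => -weight }
  .take(10)
  .foreach { case (weight, idx) => println(s"feature $idx -> $weight") }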
