spark GBT算法

原创
2017/07/12 14:46
阅读数 79

梯度提升树(Gradient-Boosted Trees, GBT)是一种流行的集成学习方法，它通过依次训练并组合多棵决策树来完成回归和分类任务（本文用它做二分类）。

相对于随机森林(Random Forest)，GBT 在不少实际应用中效果更好，但通常需要更仔细地调参，且训练只能按树依次进行。

直接上代码

package mllib

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.GBTClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature._
import org.apache.spark.sql.SparkSession

/**
  * One parsed row of the user-profile training data.
  *
  * Field names mirror the source columns read by `GBT_Profile`. As populated there, String
  * fields hold categorical values (the `\N` marker replaced by "N") and Double fields hold
  * numeric values (the `\N` marker replaced by 0). `text` is the '|'-joined concatenation of
  * the categorical fields, used as Word2Vec input, and `flag` is the raw class label string.
  *
  * NOTE(review): the class name is misspelled ("Fearture"); it is kept unchanged because other
  * code may refer to it.
  *
  * Created by dongdong on 17/7/10.
  */
final case class Fearture_One(
                         cid: String,
                         population_gender: String,
                         population_age: Double,
                         population_registered_gps_city: String,
                         population_education_nature: String,
                         population_university_level: String,
                         sociality_channel_type: String,
                         action_registered_channel: String,
                         action_this_month_once_week_average_login_count: Double,
                         population_censu_city: String,
                         population_gps_city: String,
                         population_own_cell_city: String,
                         population_rank1_cell_city: String,
                         population_rank1_cell_cnt: Double,
                         population_rank2_cell_city: String,
                         population_rank2_cell_cnt: Double,
                         population_rank3_cell_city: String,
                         population_rank3_cell_cnt: Double,
                         population_gps_censu_flag: Double,
                         population_own_censu_flag: Double,
                         population_gps_own_flag: Double,
                         population_own_txl_flag: Double,
                         population_gps_txl_flag: Double,
                         population_censu_txl_flag: Double,
                         population_cnt_7day_province: Double,
                         population_cnt_7day_city: Double,
                         population_cnt_login: Double,
                         population_before_apply_city: String,
                         population_after_apply_city: String,
                         population_before_in_apply_address: Double,
                         population_before_after_apply_address: Double,
                         population_in_after_apply_address: Double,
                         population_re_address_steady: String,
                         population_apply_address_steady: String,
                         population_score_fake_gps: Double,
                         population_score_fake_contacts: Double,
                         text: String,
                         flag: String
                       )

/**
  * Trains a gradient-boosted-tree (GBT) classifier on user-profile features and reports its
  * error on a held-out 20% test split.
  *
  * Input rows are `\001`-delimited text in which `\N` marks a missing value. Categorical columns
  * are joined into a '|'-separated sentence and embedded with Word2Vec. That embedding is
  * assembled together with the numeric columns into the GBT feature vector.
  */
object GBT_Profile {

  /** Input path used when no first program argument is given. */
  private val DefaultInputPath = "/Users/ant_git/src/data/user_profile_train/part-00000"

  /** Output directory for rows predicted as "1" when no second program argument is given. */
  private val DefaultOutputPath = "/Users/ant_git/Antifraud/src/data/predict/"

  /** Fixed seed so the train/test split, and therefore the reported error, is reproducible. */
  private val SplitSeed = 42L

  /**
    * Trains the pipeline on the training split and evaluates it on the test split. It prints
    * per-class counts and the test error, then writes the rows predicted as "1" to CSV.
    *
    * @param args optional: `args(0)` input path, `args(1)` output directory. The hard-coded
    *             defaults are used when an argument is absent. The CSV write uses Spark's
    *             default save mode, which fails if the output directory already exists.
    */
  def main(args: Array[String]): Unit = {
    val inpath1 = args.lift(0).getOrElse(DefaultInputPath)
    val outputPath = args.lift(1).getOrElse(DefaultOutputPath)

    val spark = SparkSession
      .builder()
      .master("local[3]")
      .appName("GBT_Profile")
      .getOrCreate()
    import spark.implicits._

    try {
      // Read the raw text and turn it into a typed Dataset.
      val originalData = spark.sparkContext
        .textFile(inpath1)
        .map(parseLine)
        .toDS

      // Label -> index. Fitted up front so its labels can be handed to IndexToString below.
      val labelIndexer = new StringIndexer()
        .setInputCol("flag")
        .setOutputCol("indexedLabel")
        .fit(originalData)

      // Split the '|'-joined categorical sentence into words.
      val tokenizer = new RegexTokenizer()
        .setInputCol("text")
        .setOutputCol("words")
        .setPattern("\\|")

      // Embed the words as a single vector.
      val word2Vec = new Word2Vec()
        .setInputCol("words")
        .setOutputCol("word2feature")
        .setVectorSize(100)
        //.setMinCount(1)
        .setMaxIter(10)

      // Numeric columns plus the Word2Vec embedding form the GBT feature vector.
      val featureCols = Array(
        "population_age",
        "action_this_month_once_week_average_login_count",
        "population_rank1_cell_cnt",
        "population_rank2_cell_cnt",
        "population_rank3_cell_cnt",
        "population_gps_censu_flag",
        "population_own_censu_flag",
        "population_gps_own_flag",
        "population_own_txl_flag",
        "population_gps_txl_flag",
        "population_censu_txl_flag",
        "population_cnt_7day_province",
        "population_cnt_7day_city",
        "population_cnt_login",
        "population_before_in_apply_address",
        "population_before_after_apply_address",
        "population_in_after_apply_address",
        "population_score_fake_gps",
        "population_score_fake_contacts",
        "word2feature"
      )
      val vectorAssembler = new VectorAssembler()
        .setInputCols(featureCols)
        .setOutputCol("assemblerVector")

      val gbt = new GBTClassifier()
        .setLabelCol("indexedLabel")
        .setFeaturesCol("assemblerVector")
        .setMaxIter(25) // number of boosting iterations (trees)
        .setMaxDepth(5) // depth of each tree

      // Map predicted indices back to the original flag strings.
      val labelConverter = new IndexToString()
        .setInputCol("prediction")
        .setOutputCol("predictedLabel")
        .setLabels(labelIndexer.labels)

      val Array(trainingData, testData) = originalData.randomSplit(Array(0.8, 0.2), SplitSeed)

      val pipeline = new Pipeline()
        .setStages(Array(labelIndexer, tokenizer, word2Vec, vectorAssembler, gbt, labelConverter))

      // Fit on the training split only: fitting on all data would leak the test rows into the
      // model and make the reported test error meaningless.
      val model = pipeline.fit(trainingData)

      // Cached because several actions below scan it.
      val predictionResultDF = model.transform(testData).cache()

      predictionResultDF.show(false)

      // flag and predictedLabel are both string columns, so compare against string literals.
      val label_1 = predictionResultDF.filter($"flag" === "1").count()
      val correct_1 = predictionResultDF
        .filter($"flag" === $"predictedLabel" && $"predictedLabel" === "1")
        .count()
      val correct_0 = predictionResultDF
        .filter($"flag" === $"predictedLabel" && $"predictedLabel" === "0")
        .count()
      println(s"flag=1 rows: $label_1, correctly predicted 1: $correct_1, correctly predicted 0: $correct_0")

      predictionResultDF
        .select("cid", "predictedLabel")
        .filter($"predictedLabel" === "1")
        .repartition(1)
        .write
        .format("csv")
        .save(outputPath)

      val evaluator = new MulticlassClassificationEvaluator()
        .setLabelCol("indexedLabel")
        .setPredictionCol("prediction")
        .setMetricName("accuracy")
      val accuracy = evaluator.evaluate(predictionResultDF)
      val error = 1.0 - accuracy
      println("Test Error = " + error)
    } finally {
      spark.stop()
    }
  }

  /**
    * Parses one `\001`-delimited input line into a [[Fearture_One]].
    *
    * Categorical columns have the `\N` marker replaced by "N". Numeric columns have `\N`
    * replaced by "0" before being parsed as Double. Column 141 is taken verbatim as the label.
    * Columns 65-66, 74-77, 79-83, 85-86, 88-90 and 94-95 exist in the input but are
    * intentionally not used as features.
    *
    * Throws ArrayIndexOutOfBoundsException if the line has fewer than 142 fields. Throws
    * NumberFormatException if a numeric column is not a valid number, for example an
    * empty string.
    */
  private def parseLine(line: String): Fearture_One = {
    // limit -1 keeps trailing empty fields so the fixed column indices stay aligned.
    val arr = line.split("\001", -1)
    def cat(i: Int): String = arr(i).replace("\\N", "N")
    def num(i: Int): Double = arr(i).replace("\\N", "0").toDouble

    val gender = cat(3)
    val registeredGpsCity = cat(7)
    val educationNature = cat(10)
    val universityLevel = cat(11)
    val channelType = cat(13)
    val registeredChannel = cat(44)
    val censuCity = cat(63)
    val gpsCity = cat(64)
    val ownCellCity = cat(67)
    val rank1CellCity = cat(68)
    val rank2CellCity = cat(70)
    val rank3CellCity = cat(72)
    val beforeApplyCity = cat(107)
    val afterApplyCity = cat(108)
    val reAddressSteady = cat(116)
    val applyAddressSteady = cat(117)

    // The categorical fields, '|'-joined in this order, form the Word2Vec input sentence.
    val sentence = Seq(
      gender, registeredGpsCity, educationNature, universityLevel, channelType,
      registeredChannel, censuCity, gpsCity, ownCellCity, rank1CellCity,
      rank2CellCity, rank3CellCity, beforeApplyCity, afterApplyCity,
      reAddressSteady, applyAddressSteady
    ).mkString("|")

    Fearture_One(
      cid = arr(0),
      population_gender = gender,
      population_age = num(4),
      population_registered_gps_city = registeredGpsCity,
      population_education_nature = educationNature,
      population_university_level = universityLevel,
      sociality_channel_type = channelType,
      action_registered_channel = registeredChannel,
      action_this_month_once_week_average_login_count = num(54),
      population_censu_city = censuCity,
      population_gps_city = gpsCity,
      population_own_cell_city = ownCellCity,
      population_rank1_cell_city = rank1CellCity,
      population_rank1_cell_cnt = num(69),
      population_rank2_cell_city = rank2CellCity,
      population_rank2_cell_cnt = num(71),
      population_rank3_cell_city = rank3CellCity,
      population_rank3_cell_cnt = num(73),
      population_gps_censu_flag = num(78),
      population_own_censu_flag = num(84),
      population_gps_own_flag = num(87),
      population_own_txl_flag = num(91),
      population_gps_txl_flag = num(92),
      population_censu_txl_flag = num(93),
      population_cnt_7day_province = num(96),
      population_cnt_7day_city = num(97),
      population_cnt_login = num(102),
      population_before_apply_city = beforeApplyCity,
      population_after_apply_city = afterApplyCity,
      population_before_in_apply_address = num(111),
      population_before_after_apply_address = num(112),
      population_in_after_apply_address = num(113),
      population_re_address_steady = reAddressSteady,
      population_apply_address_steady = applyAddressSteady,
      population_score_fake_gps = num(127),
      population_score_fake_contacts = num(128),
      text = sentence,
      flag = arr(141)
    )
  }
}

总结：算法本身已经由 Spark 封装好，真正关键的是如何处理特征。特征好，很简单的算法也能取得不错的分类效果；特征不好，再好的模型也很难有好的效果。因此，特征的选择与处理对机器学习来说非常重要。另外要注意，模型只能在训练集上拟合，测试集只用于评估，否则得到的测试误差会偏乐观。

展开阅读全文
打赏
1
0 收藏
分享
加载中
更多评论
打赏
0 评论
0 收藏
1
分享
返回顶部
顶部