
Spark 2 Linear Regression

hblt-j
Published 2017/08/29 11:41

Regularized regression methods (Lasso, Ridge, and ElasticNet) work well for high-dimensional data and for datasets with multicollinearity among the variables.

 

Mathematically, ElasticNet is defined as a convex combination of the L1 and L2 regularization terms. In Spark ML the regularization term added to the least-squares loss is

    λ · ( α‖w‖₁ + (1 − α)/2 · ‖w‖₂² ),  with λ ≥ 0 and α ∈ [0, 1]

By choosing α appropriately, ElasticNet contains both L1 and L2 regularization as special cases. For example, training a linear regression model with α set to 1 is equivalent to a Lasso model; conversely, setting α to 0 reduces the trained model to ridge regression.

RegParam: λ (regularization strength), λ ≥ 0
ElasticNetParam: α (L1/L2 mixing ratio), α ∈ [0, 1]
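As a quick illustration, the two special cases map onto Spark ML's LinearRegression parameters as in the sketch below (the regParam value 0.3 and the names lassoLike / ridgeLike are only illustrative):

import org.apache.spark.ml.regression.LinearRegression

// alpha = 1 -> pure L1 penalty (Lasso)
val lassoLike = new LinearRegression().setRegParam(0.3).setElasticNetParam(1.0)
// alpha = 0 -> pure L2 penalty (ridge)
val ridgeLike = new LinearRegression().setRegParam(0.3).setElasticNetParam(0.0)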

 

Import packages


import org.apache.spark.sql.SparkSession

import org.apache.spark.sql.Dataset

import org.apache.spark.sql.Row

import org.apache.spark.sql.DataFrame

import org.apache.spark.sql.Column

import org.apache.spark.sql.DataFrameReader

import org.apache.spark.rdd.RDD

import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

import org.apache.spark.sql.Encoder

import org.apache.spark.sql.DataFrameStatFunctions

import org.apache.spark.sql.functions._

 

import org.apache.spark.ml.linalg.Vectors

import org.apache.spark.ml.feature.VectorAssembler

import org.apache.spark.ml.evaluation.RegressionEvaluator

import org.apache.spark.ml.regression.LinearRegression

 

Load the sample data


// Columns of the sample data:
// Population - state population
// Income - per-capita income
// Illiteracy - illiteracy rate
// LifeExp - life expectancy
// Murder - murder rate (the label we will predict)
// HSGrad - percentage of high-school graduates
// Frost - mean number of days with temperatures below freezing
// Area - state land area

    val spark = SparkSession.builder().appName("Spark Linear Regression").config("spark.some.config.option", "some-value").getOrCreate()

 

    // For implicit conversions like converting RDDs to DataFrames

    import spark.implicits._

 

    val dataList: List[(Double, Double, Double, Double, Double, Double, Double, Double)] = List(

      (3615, 3624, 2.1, 69.05, 15.1, 41.3, 20, 50708),

      (365, 6315, 1.5, 69.31, 11.3, 66.7, 152, 566432),

      (2212, 4530, 1.8, 70.55, 7.8, 58.1, 15, 113417),

      (2110, 3378, 1.9, 70.66, 10.1, 39.9, 65, 51945),

      (21198, 5114, 1.1, 71.71, 10.3, 62.6, 20, 156361),

      (2541, 4884, 0.7, 72.06, 6.8, 63.9, 166, 103766),

      (3100, 5348, 1.1, 72.48, 3.1, 56, 139, 4862),

      (579, 4809, 0.9, 70.06, 6.2, 54.6, 103, 1982),

      (8277, 4815, 1.3, 70.66, 10.7, 52.6, 11, 54090),

      (4931, 4091, 2, 68.54, 13.9, 40.6, 60, 58073),

      (868, 4963, 1.9, 73.6, 6.2, 61.9, 0, 6425),

      (813, 4119, 0.6, 71.87, 5.3, 59.5, 126, 82677),

      (11197, 5107, 0.9, 70.14, 10.3, 52.6, 127, 55748),

      (5313, 4458, 0.7, 70.88, 7.1, 52.9, 122, 36097),

      (2861, 4628, 0.5, 72.56, 2.3, 59, 140, 55941),

      (2280, 4669, 0.6, 72.58, 4.5, 59.9, 114, 81787),

      (3387, 3712, 1.6, 70.1, 10.6, 38.5, 95, 39650),

      (3806, 3545, 2.8, 68.76, 13.2, 42.2, 12, 44930),

      (1058, 3694, 0.7, 70.39, 2.7, 54.7, 161, 30920),

      (4122, 5299, 0.9, 70.22, 8.5, 52.3, 101, 9891),

      (5814, 4755, 1.1, 71.83, 3.3, 58.5, 103, 7826),

      (9111, 4751, 0.9, 70.63, 11.1, 52.8, 125, 56817),

      (3921, 4675, 0.6, 72.96, 2.3, 57.6, 160, 79289),

      (2341, 3098, 2.4, 68.09, 12.5, 41, 50, 47296),

      (4767, 4254, 0.8, 70.69, 9.3, 48.8, 108, 68995),

      (746, 4347, 0.6, 70.56, 5, 59.2, 155, 145587),

      (1544, 4508, 0.6, 72.6, 2.9, 59.3, 139, 76483),

      (590, 5149, 0.5, 69.03, 11.5, 65.2, 188, 109889),

      (812, 4281, 0.7, 71.23, 3.3, 57.6, 174, 9027),

      (7333, 5237, 1.1, 70.93, 5.2, 52.5, 115, 7521),

      (1144, 3601, 2.2, 70.32, 9.7, 55.2, 120, 121412),

      (18076, 4903, 1.4, 70.55, 10.9, 52.7, 82, 47831),

      (5441, 3875, 1.8, 69.21, 11.1, 38.5, 80, 48798),

      (637, 5087, 0.8, 72.78, 1.4, 50.3, 186, 69273),

      (10735, 4561, 0.8, 70.82, 7.4, 53.2, 124, 40975),

      (2715, 3983, 1.1, 71.42, 6.4, 51.6, 82, 68782),

      (2284, 4660, 0.6, 72.13, 4.2, 60, 44, 96184),

      (11860, 4449, 1, 70.43, 6.1, 50.2, 126, 44966),

      (931, 4558, 1.3, 71.9, 2.4, 46.4, 127, 1049),

      (2816, 3635, 2.3, 67.96, 11.6, 37.8, 65, 30225),

      (681, 4167, 0.5, 72.08, 1.7, 53.3, 172, 75955),

      (4173, 3821, 1.7, 70.11, 11, 41.8, 70, 41328),

      (12237, 4188, 2.2, 70.9, 12.2, 47.4, 35, 262134),

      (1203, 4022, 0.6, 72.9, 4.5, 67.3, 137, 82096),

      (472, 3907, 0.6, 71.64, 5.5, 57.1, 168, 9267),

      (4981, 4701, 1.4, 70.08, 9.5, 47.8, 85, 39780),

      (3559, 4864, 0.6, 71.72, 4.3, 63.5, 32, 66570),

      (1799, 3617, 1.4, 69.48, 6.7, 41.6, 100, 24070),

      (4589, 4468, 0.7, 72.48, 3, 54.5, 149, 54464),

      (376, 4566, 0.6, 70.29, 6.9, 62.9, 173, 97203))

 

    val data = dataList.toDF("Population", "Income", "Illiteracy", "LifeExp", "Murder", "HSGrad", "Frost", "Area")
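Since the opening paragraph mentions multicollinearity, and DataFrameStatFunctions is already imported, a quick sanity check of pairwise feature correlations can be sketched as follows; the column pairs picked here are only illustrative:

// Optional check: pairwise Pearson correlations between a few candidate features
val featurePairs = Seq(("Income", "HSGrad"), ("Illiteracy", "Murder"), ("Population", "Area"))
featurePairs.foreach { case (a, b) =>
  println(s"corr($a, $b) = ${data.stat.corr(a, b)}")
}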

 

Build the linear regression model


val colArray = Array("Population", "Income", "Illiteracy", "LifeExp", "HSGrad", "Frost", "Area")

 

val assembler = new VectorAssembler().setInputCols(colArray).setOutputCol("features")

 

val vecDF: DataFrame = assembler.transform(data)

 

// Build the model to predict the murder rate (Murder)
// Configure the linear regression parameters
val lr = new LinearRegression()
  .setFeaturesCol("features")
  .setLabelCol("Murder")
  .setFitIntercept(true)
  .setMaxIter(10)
  .setRegParam(0.3)          // RegParam: regularization strength (lambda)
  .setElasticNetParam(0.8)   // ElasticNetParam: L1/L2 mixing ratio (alpha)

 

// Fit the model

val lrModel = lr.fit(vecDF)

 

// Print all model parameters

lrModel.extractParamMap()

// Print the coefficients and intercept for linear regression

println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")

 

val predictions = lrModel.transform(vecDF)

predictions.selectExpr("Murder", "round(prediction,1) as prediction").show

 

// Summarize the model over the training set and print out some metrics

val trainingSummary = lrModel.summary

println(s"numIterations: ${trainingSummary.totalIterations}")

println(s"objectiveHistory: ${trainingSummary.objectiveHistory.toList}")

trainingSummary.residuals.show()

println(s"RMSE: ${trainingSummary.rootMeanSquaredError}")

println(s"r2: ${trainingSummary.r2}")

 

Code execution results


// Print all model parameters

lrModel.extractParamMap()

res15: org.apache.spark.ml.param.ParamMap =

{

    linReg_2ba28140e39a-elasticNetParam: 0.8,

    linReg_2ba28140e39a-featuresCol: features,

    linReg_2ba28140e39a-fitIntercept: true,

    linReg_2ba28140e39a-labelCol: Murder,

    linReg_2ba28140e39a-maxIter: 10,

    linReg_2ba28140e39a-predictionCol: prediction,

    linReg_2ba28140e39a-regParam: 0.3,

    linReg_2ba28140e39a-solver: auto,

    linReg_2ba28140e39a-standardization: true,

    linReg_2ba28140e39a-tol: 1.0E-6

}

 

// Print the coefficients and intercept for linear regression

println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")

Coefficients: [1.36662199778084E-4,0.0,1.1834384307116244,-1.4580829641757522,0.0,-0.010686434270049252,4.051355050528196E-6] Intercept: 109.589659881471

 

val predictions = lrModel.transform(vecDF)

predictions: org.apache.spark.sql.DataFrame = [Population: double, Income: double ... 8 more fields]

 

predictions.selectExpr("Murder", "round(prediction,1) as prediction").show

+------+----------+
|Murder|prediction|
+------+----------+
|  15.1|      11.9|
|  11.3|      11.0|
|   7.8|       9.5|
|  10.1|       8.6|
|  10.3|       9.6|
|   6.8|       4.3|
|   3.1|       4.2|
|   6.2|       7.5|
|  10.7|       9.3|
|  13.9|      12.3|
|   6.2|       4.7|
|   5.3|       4.6|
|  10.3|       8.8|
|   7.1|       6.6|
|   2.3|       3.5|
|   4.5|       3.9|
|  10.6|       8.9|
|  13.2|      13.2|
|   2.7|       6.3|
|   8.5|       7.8|
+------+----------+
only showing top 20 rows

 

// Summarize the model over the training set and print out some metrics

val trainingSummary = lrModel.summary

trainingSummary: org.apache.spark.ml.regression.LinearRegressionTrainingSummary = org.apache.spark.ml.regression.LinearRegressionTrainingSummary@68a83d76

 

println(s"numIterations: ${trainingSummary.totalIterations}")

numIterations: 11

 

println(s"objectiveHistory: ${trainingSummary.objectiveHistory.toList}")

objectiveHistory: List(0.49000000000000016, 0.3919242806809093, 0.19908078426904946, 0.1901453492751914, 0.17981874256031405, 0.17878173084286247, 0.1787617816935607, 0.17875431854661641, 0.17874702637141196, 0.17874512271568685, 0.1787449876896829)

trainingSummary.residuals.show()

+--------------------+
|           residuals|
+--------------------+
|  3.2200068116713023|
|  0.2745518816306607|
| -1.6535887417767414|
|   1.485762696757325|
|  0.6509766532389172|
|   2.457688146554534|
| -1.0675250558261182|
| -1.2879164685248439|
|  1.3672723619868314|
|  1.6125000289597242|
|   1.532060517905248|
|  0.6931301635074645|
|  1.5163001982000175|
| 0.46227066807431605|
| -1.2044058248740273|
|  0.6032541157521649|
|     1.7201545753635|
|-0.01942937427384...|
|  -3.632947522687547|
|  0.7077675962948007|
+--------------------+
only showing top 20 rows

 

println(s"RMSE: ${trainingSummary.rootMeanSquaredError}")

RMSE: 1.6663615527314546

 

println(s"r2: ${trainingSummary.r2}")

r2: 0.7920794990832152

 

Model tuning with Train-Validation Split


// Additional imports needed for the tuning pipeline (not in the import block above)
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}

val colArray = Array("Population", "Income", "Illiteracy", "LifeExp", "HSGrad", "Frost", "Area")

 

val vecDF: DataFrame = new VectorAssembler().setInputCols(colArray).setOutputCol("features").transform(data)

 

val Array(trainingDF, testDF) = vecDF.randomSplit(Array(0.9, 0.1), seed = 12345)

 

// Build the model to predict the murder rate (Murder); configure the linear regression estimator.
// Note: the estimator is NOT fitted here -- it must stay unfitted so that the pipeline and
// parameter grid below can train it once per parameter combination.
val lr = new LinearRegression().setFeaturesCol("features").setLabelCol("Murder")

// Set up the pipeline
val pipeline = new Pipeline().setStages(Array(lr))

// Build the parameter grid (2 fitIntercept x 3 elasticNetParam x 2 maxIter = 12 combinations)
val paramGrid = new ParamGridBuilder().addGrid(lr.fitIntercept).addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0)).addGrid(lr.maxIter, Array(10, 100)).build()

// Select (prediction, true label) pairs and compute the validation error.
// Note RegEvaluator.isLargerBetter: the framework determines automatically whether a larger
// or a smaller metric value counts as better for the chosen metric.
val RegEvaluator = new RegressionEvaluator().setLabelCol(lr.getLabelCol).setPredictionCol(lr.getPredictionCol).setMetricName("rmse")

val trainValidationSplit = new TrainValidationSplit().setEstimator(pipeline).setEvaluator(RegEvaluator).setEstimatorParamMaps(paramGrid).setTrainRatio(0.8) // train/validation split ratio within the training data

 

// Run train validation split, and choose the best set of parameters.

val tvModel = trainValidationSplit.fit(trainingDF)

 

// Inspect all parameters of the fitted model
tvModel.extractParamMap()

tvModel.getEstimatorParamMaps.length
tvModel.getEstimatorParamMaps.foreach { println } // the set of parameter combinations that were tried

tvModel.getEvaluator.extractParamMap() // the evaluator's parameters

tvModel.getEvaluator.isLargerBetter // whether a larger metric value is better (false for RMSE)

tvModel.getTrainRatio

// Predict on the test set using the best parameter combination
tvModel.transform(testDF).select("features", "Murder", "prediction").show()
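After fitting, the winning parameter combination lives inside tvModel.bestModel; a minimal sketch of unwrapping it (the casts assume the single-stage pipeline built above):

import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.regression.LinearRegressionModel

// Unwrap the best PipelineModel and inspect its LinearRegression stage
val bestLrModel = tvModel.bestModel.asInstanceOf[PipelineModel].stages(0).asInstanceOf[LinearRegressionModel]
println(s"Best elasticNetParam: ${bestLrModel.getElasticNetParam}, maxIter: ${bestLrModel.getMaxIter}, fitIntercept: ${bestLrModel.getFitIntercept}")
println(s"Coefficients: ${bestLrModel.coefficients} Intercept: ${bestLrModel.intercept}")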

 

Tuning code execution results


// Inspect all parameters of the fitted model

tvModel.extractParamMap()

res45: org.apache.spark.ml.param.ParamMap =

{

    tvs_5de7d3dd1977-estimator: pipeline_062a1dffe557,

    tvs_5de7d3dd1977-estimatorParamMaps: [Lorg.apache.spark.ml.param.ParamMap;@60298de1,

    tvs_5de7d3dd1977-evaluator: regEval_05204824acb9,

    tvs_5de7d3dd1977-seed: -1772833110,

    tvs_5de7d3dd1977-trainRatio: 0.8

}

 

tvModel.getEstimatorParamMaps.length

res46: Int = 12

 

tvModel.getEstimatorParamMaps.foreach { println } // the set of parameter combinations that were tried

{

    linReg_75628a5554b4-elasticNetParam: 0.0,

    linReg_75628a5554b4-fitIntercept: true,

    linReg_75628a5554b4-maxIter: 10

}

{

    linReg_75628a5554b4-elasticNetParam: 0.0,

    linReg_75628a5554b4-fitIntercept: true,

    linReg_75628a5554b4-maxIter: 100

}

{

    linReg_75628a5554b4-elasticNetParam: 0.0,

    linReg_75628a5554b4-fitIntercept: false,

    linReg_75628a5554b4-maxIter: 10

}

{

    linReg_75628a5554b4-elasticNetParam: 0.0,

    linReg_75628a5554b4-fitIntercept: false,

    linReg_75628a5554b4-maxIter: 100

}

{

    linReg_75628a5554b4-elasticNetParam: 0.5,

    linReg_75628a5554b4-fitIntercept: true,

    linReg_75628a5554b4-maxIter: 10

}

{

    linReg_75628a5554b4-elasticNetParam: 0.5,

    linReg_75628a5554b4-fitIntercept: true,

    linReg_75628a5554b4-maxIter: 100

}

{

    linReg_75628a5554b4-elasticNetParam: 0.5,

    linReg_75628a5554b4-fitIntercept: false,

    linReg_75628a5554b4-maxIter: 10

}

{

    linReg_75628a5554b4-elasticNetParam: 0.5,

    linReg_75628a5554b4-fitIntercept: false,

    linReg_75628a5554b4-maxIter: 100

}

{

    linReg_75628a5554b4-elasticNetParam: 1.0,

    linReg_75628a5554b4-fitIntercept: true,

    linReg_75628a5554b4-maxIter: 10

}

{

    linReg_75628a5554b4-elasticNetParam: 1.0,

    linReg_75628a5554b4-fitIntercept: true,

    linReg_75628a5554b4-maxIter: 100

}

{

    linReg_75628a5554b4-elasticNetParam: 1.0,

    linReg_75628a5554b4-fitIntercept: false,

    linReg_75628a5554b4-maxIter: 10

}

{

    linReg_75628a5554b4-elasticNetParam: 1.0,

    linReg_75628a5554b4-fitIntercept: false,

    linReg_75628a5554b4-maxIter: 100

}

 

tvModel.getEvaluator.extractParamMap() // the evaluator's parameters

res48: org.apache.spark.ml.param.ParamMap =

{

    regEval_05204824acb9-labelCol: Murder,

    regEval_05204824acb9-metricName: rmse,

    regEval_05204824acb9-predictionCol: prediction

}

 

tvModel.getEvaluator.isLargerBetter // whether a larger metric value is better (false for RMSE)

res49: Boolean = false

 

tvModel.getTrainRatio

res50: Double = 0.8

 

tvModel.transform(testDF).select("features", "Murder", "prediction").show()

+--------------------+------+------------------+

|            features|Murder|        prediction|

+--------------------+------+------------------+

|[1058.0,3694.0,0....|   2.7| 6.917232043935343|

|[2341.0,3098.0,2....|  12.5|14.760329005533478|

|[472.0,3907.0,0.6...|   5.5| 4.182074651181182|

|[812.0,4281.0,0.7...|   3.3| 4.915905572667441|

|[2816.0,3635.0,2....|  11.6|14.219231061596304|

|[4589.0,4468.0,0....|   3.0| 3.483554528704758|

+--------------------+------+------------------+

 

Reposted from: http://www.cnblogs.com/wwxbi/p/6028261.html
