LogisticRegression: Building a Logistic Regression Model

hblt-j
Published 2017/08/29 14:11

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.Row
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Column
import org.apache.spark.sql.DataFrameReader
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.DataFrameStatFunctions
import org.apache.spark.sql.functions._

import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.classification.{ BinaryLogisticRegressionSummary, LogisticRegression }
import org.apache.spark.ml.tuning.{ ParamGridBuilder, TrainValidationSplit }

val spark = SparkSession.builder().appName("Spark Logistic Regression").config("spark.some.config.option", "some-value").getOrCreate()

 

// For implicit conversions like converting RDDs to DataFrames

import spark.implicits._

 

val dataList: List[(Double, String, Double, Double, String, Double, Double, Double, Double)] = List(
      (0, "male", 37, 10, "no", 3, 18, 7, 4),
      (0, "female", 27, 4, "no", 4, 14, 6, 4)
      // ... the remaining rows of the dataset are omitted here, as in the original post
    )

 

sqlDF.show()

+-------+------+----+------------+--------+-------------+---------+----------+------+
|affairs|gender| age|yearsmarried|children|religiousness|education|occupation|rating|
+-------+------+----+------------+--------+-------------+---------+----------+------+
|      0|     1|37.0|        10.0|       0|          3.0|     18.0|       7.0|   4.0|
|      0|     0|27.0|         4.0|       0|          4.0|     14.0|       6.0|   4.0|
|      0|     0|32.0|        15.0|       1|          1.0|     12.0|       1.0|   4.0|
|      0|     1|57.0|        15.0|       1|          5.0|     18.0|       6.0|   5.0|
|      0|     1|22.0|        0.75|       0|          2.0|     17.0|       6.0|   3.0|
|      0|     0|32.0|         1.5|       0|          2.0|     17.0|       5.0|   5.0|
|      0|     0|22.0|        0.75|       0|          2.0|     12.0|       1.0|   3.0|
|      0|     1|57.0|        15.0|       1|          2.0|     14.0|       4.0|   4.0|
|      0|     0|32.0|        15.0|       1|          4.0|     16.0|       1.0|   2.0|
|      0|     1|22.0|         1.5|       0|          4.0|     14.0|       4.0|   5.0|
|      0|     1|37.0|        15.0|       1|          2.0|     20.0|       7.0|   2.0|
|      0|     1|27.0|         4.0|       1|          4.0|     18.0|       6.0|   4.0|
|      0|     1|47.0|        15.0|       1|          5.0|     17.0|       6.0|   4.0|
|      0|     0|22.0|         1.5|       0|          2.0|     17.0|       5.0|   4.0|
|      0|     0|27.0|         4.0|       0|          4.0|     14.0|       5.0|   4.0|
|      0|     0|37.0|        15.0|       1|          1.0|     17.0|       5.0|   5.0|
|      0|     0|37.0|        15.0|       1|          2.0|     18.0|       4.0|   3.0|
|      0|     0|22.0|        0.75|       0|          3.0|     16.0|       5.0|   4.0|
|      0|     0|22.0|         1.5|       0|          2.0|     16.0|       5.0|   5.0|
|      0|     0|27.0|        10.0|       1|          2.0|     14.0|       1.0|   5.0|
+-------+------+----+------------+--------+-------------+---------+----------+------+
only showing top 20 rows

 

val colArray2 = Array("gender", "age", "yearsmarried", "children", "religiousness", "education", "occupation", "rating")

colArray2: Array[String] = Array(gender, age, yearsmarried, children, religiousness, education, occupation, rating)

 

val vecDF: DataFrame = new VectorAssembler().setInputCols(colArray2).setOutputCol("features").transform(sqlDF)

vecDF: org.apache.spark.sql.DataFrame = [affairs: int, gender: int ... 8 more fields]

 

 

val Array(trainingDF, testDF) = vecDF.randomSplit(Array(0.9, 0.1), seed = 12345)

trainingDF: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [affairs: int, gender: int ... 8 more fields]

testDF: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [affairs: int, gender: int ... 8 more fields]

 

val lrModel = new LogisticRegression().setLabelCol("affairs").setFeaturesCol("features").fit(trainingDF)

lrModel: org.apache.spark.ml.classification.LogisticRegressionModel = logreg_9d8a91cb1a0b

 

 

// Print the coefficients and intercept of the fitted logistic regression model

println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")

Coefficients: [0.308688148697453,-0.04150802586369178,0.08771801000466706,0.6896853841812993,-0.3425440049065515,0.008629892776596084,0.0458687806620022,-0.46268114569065383] Intercept: 1.263200227888706
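The coefficient vector lines up with colArray2 defined above, so the weights can be printed next to their feature names (a small convenience sketch, not part of the original post):

// Pair each input column with its fitted weight
colArray2.zip(lrModel.coefficients.toArray).foreach { case (name, w) =>
  println(f"$name%-14s $w%+.4f")
}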

  

 

// The ElasticNet mixing parameter, in the range [0, 1].
// For alpha = 0 the penalty is an L2 penalty; for alpha = 1 it is an L1 penalty; for 0 < alpha < 1 it is a combination of L1 and L2. The default is 0.0, an L2 penalty.

lrModel.getElasticNetParam

res5: Double = 0.0

 

lrModel.getRegParam  // regularization parameter, >= 0

res6: Double = 0.0

 

lrModel.getStandardization  // whether features are standardized before fitting the model

res7: Boolean = true

 

// The threshold for binary classification, in the range [0, 1]. If the estimated probability of class label 1 is greater than the threshold, predict 1, otherwise 0. A high threshold pushes the model to predict 0 more often; a low threshold pushes it to predict 1 more often. The default is 0.5.

lrModel.getThreshold

res8: Double = 0.5

 

// The convergence tolerance for the iterative optimization. Smaller values give higher accuracy at the cost of more iterations. The default is 1E-6.

lrModel.getTol

res9: Double = 1.0E-6
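All of the values above are defaults, because the model was fit without setting any of these parameters. A hypothetical configuration sketch showing where each one would be set before fitting (the values are illustrative only, not from the post):

val tunedLR = new LogisticRegression()
  .setLabelCol("affairs")
  .setFeaturesCol("features")
  .setMaxIter(100)          // maximum number of iterations
  .setRegParam(0.01)        // regularization parameter, >= 0
  .setElasticNetParam(0.5)  // 0 = pure L2, 1 = pure L1
  .setThreshold(0.5)        // decision threshold on P(label = 1)
  .setTol(1e-6)             // convergence tolerance
val tunedModel = tunedLR.fit(trainingDF)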

 

 

lrModel.transform(testDF).show

+-------+------+----+------------+--------+-------------+---------+----------+------+--------------------+--------------------+--------------------+----------+
|affairs|gender| age|yearsmarried|children|religiousness|education|occupation|rating|            features|       rawPrediction|         probability|prediction|
+-------+------+----+------------+--------+-------------+---------+----------+------+--------------------+--------------------+--------------------+----------+
|      0|     0|22.0|       0.125|       0|          4.0|     14.0|       4.0|   5.0|[0.0,22.0,0.125,0...|[3.01829971642105...|[0.95339403355398...|       0.0|
|      0|     0|22.0|       0.417|       1|          3.0|     14.0|       3.0|   5.0|[0.0,22.0,0.417,1...|[2.00632544907384...|[0.88145961149358...|       0.0|
|      0|     0|27.0|         1.5|       0|          2.0|     16.0|       6.0|   5.0|[0.0,27.0,1.5,0.0...|[2.31114222529279...|[0.90979563879849...|       0.0|
|      0|     0|27.0|         4.0|       1|          3.0|     18.0|       4.0|   5.0|[0.0,27.0,4.0,1.0...|[1.81918359677719...|[0.86046813628746...|       0.0|
|      0|     0|27.0|         7.0|       1|          2.0|     18.0|       1.0|   5.0|[0.0,27.0,7.0,1.0...|[1.35109190384264...|[0.79430808378365...|       0.0|
|      0|     0|27.0|         7.0|       1|          3.0|     16.0|       1.0|   4.0|[0.0,27.0,7.0,1.0...|[1.24821454861173...|[0.77699063797650...|       0.0|
|      0|     0|27.0|        10.0|       1|          2.0|     12.0|       1.0|   4.0|[0.0,27.0,10.0,1....|[0.67703608479756...|[0.66307686153089...|       0.0|
|      0|     0|32.0|        10.0|       1|          4.0|     17.0|       5.0|   4.0|[0.0,32.0,10.0,1....|[1.34303963739813...|[0.79298936429536...|       0.0|
|      0|     0|32.0|        10.0|       1|          5.0|     14.0|       4.0|   5.0|[0.0,32.0,10.0,1....|[2.22002324698713...|[0.90203325004083...|       0.0|
|      0|     0|32.0|        15.0|       1|          3.0|     18.0|       5.0|   4.0|[0.0,32.0,15.0,1....|[0.55327568969165...|[0.63489524159656...|       0.0|
|      0|     0|37.0|        15.0|       1|          4.0|     17.0|       1.0|   5.0|[0.0,37.0,15.0,1....|[1.75814598503192...|[0.85297730582863...|       0.0|
|      0|     0|52.0|        15.0|       1|          5.0|      9.0|       5.0|   5.0|[0.0,52.0,15.0,1....|[2.60887439745861...|[0.93143054154558...|       0.0|
|      0|     0|52.0|        15.0|       1|          5.0|     12.0|       1.0|   3.0|[0.0,52.0,15.0,1....|[1.84109755039552...|[0.86307846107252...|       0.0|
|      0|     0|57.0|        15.0|       1|          4.0|     16.0|       6.0|   4.0|[0.0,57.0,15.0,1....|[1.90491134608169...|[0.87044638395268...|       0.0|
|      0|     1|22.0|         4.0|       0|          1.0|     18.0|       5.0|   5.0|[1.0,22.0,4.0,0.0...|[1.26168391246747...|[0.77931584772929...|       0.0|
|      0|     1|22.0|         4.0|       0|          2.0|     18.0|       5.0|   5.0|[1.0,22.0,4.0,0.0...|[1.60422791737402...|[0.83260846569570...|       0.0|
|      0|     1|27.0|         4.0|       1|          3.0|     16.0|       5.0|   5.0|[1.0,27.0,4.0,1.0...|[1.48188645297092...|[0.81485734920851...|       0.0|
|      0|     1|27.0|         4.0|       1|          4.0|     14.0|       5.0|   4.0|[1.0,27.0,4.0,1.0...|[1.37900909774001...|[0.79883180985416...|       0.0|
|      0|     1|32.0|       0.125|       1|          2.0|     18.0|       5.0|   2.0|[1.0,32.0,0.125,1...|[0.28148664352576...|[0.56991065665974...|       0.0|
|      0|     1|32.0|        10.0|       1|          2.0|     20.0|       6.0|   3.0|[1.0,32.0,10.0,1....|[-0.1851761257948...|[0.45383780246566...|       1.0|
+-------+------+----+------------+--------+-------------+---------+----------+------+--------------------+--------------------+--------------------+----------+
only showing top 20 rows

 

 

 

// Extract the summary from the LogisticRegressionModel instance trained in the earlier example

val trainingSummary = lrModel.summary

trainingSummary: org.apache.spark.ml.classification.LogisticRegressionTrainingSummary = org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary@4cde233d

 

 

// Obtain the objective per iteration.

val objectiveHistory = trainingSummary.objectiveHistory

objectiveHistory: Array[Double] = Array(0.5613118243072733, 0.5564125149222438, 0.5365395467216898, 0.5160918427628939, 0.51304621799159, 0.5105231964507352, 0.5079869547558363, 0.5072888873031864, 0.5067113660796532, 0.506520677080951, 0.5059147658563949, 0.5053652033316485, 0.5047266888422277, 0.5045473900598205, 0.5041496504941453, 0.5034630545828777, 0.5025745763542784, 0.5019910559468922, 0.5012033102192196, 0.5009489760675826, 0.5008431925740259, 0.5008297629370251, 0.5008258245513862, 0.5008137617093257, 0.5008136785235711, 0.5008130045533166, 0.5008129888367148, 0.5008129675120628, 0.5008129469652479, 0.5008129168191972, 0.5008129132692991, 0.5008129124596163, 0.5008129124081014, 0.500812912251931, 0.5008129121356268)

objectiveHistory.foreach(loss => println(loss))

0.5613118243072733
0.5564125149222438
0.5365395467216898
0.5160918427628939
0.51304621799159
0.5105231964507352
0.5079869547558363
0.5072888873031864
0.5067113660796532
0.506520677080951
0.5059147658563949
0.5053652033316485
0.5047266888422277
0.5045473900598205
0.5041496504941453
0.5034630545828777
0.5025745763542784
0.5019910559468922
0.5012033102192196
0.5009489760675826
0.5008431925740259
0.5008297629370251
0.5008258245513862
0.5008137617093257
0.5008136785235711
0.5008130045533166
0.5008129888367148
0.5008129675120628
0.5008129469652479
0.5008129168191972
0.5008129132692991
0.5008129124596163
0.5008129124081014
0.500812912251931
0.5008129121356268
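Because this is a binary problem, the training summary can also be cast to BinaryLogisticRegressionSummary (already imported above) to get ROC-based training metrics. A minimal sketch, following the Spark 2.x API:

// Cast the generic training summary to its binary variant
val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary]
println(s"areaUnderROC: ${binarySummary.areaUnderROC}")
binarySummary.roc.show()  // DataFrame of (FPR, TPR) points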

 

 

 lrModel.transform(testDF).select("features","rawPrediction","probability","prediction").show(30,false)

+-------------------------------------+--------------------------------------------+----------------------------------------+----------+
|features                             |rawPrediction                               |probability                             |prediction|
+-------------------------------------+--------------------------------------------+----------------------------------------+----------+
|[0.0,22.0,0.125,0.0,4.0,14.0,4.0,5.0]|[3.0182997164210517,-3.0182997164210517]    |[0.9533940335539883,0.04660596644601167]|0.0       |
|[0.0,22.0,0.417,1.0,3.0,14.0,3.0,5.0]|[2.00632544907384,-2.00632544907384]        |[0.8814596114935873,0.11854038850641263]|0.0       |
|[0.0,27.0,1.5,0.0,2.0,16.0,6.0,5.0]  |[2.311142225292793,-2.311142225292793]      |[0.9097956387984996,0.09020436120150035]|0.0       |
|[0.0,27.0,4.0,1.0,3.0,18.0,4.0,5.0]  |[1.81918359677719,-1.81918359677719]        |[0.8604681362874618,0.13953186371253828]|0.0       |
|[0.0,27.0,7.0,1.0,2.0,18.0,1.0,5.0]  |[1.351091903842644,-1.351091903842644]      |[0.7943080837836515,0.20569191621634847]|0.0       |
|[0.0,27.0,7.0,1.0,3.0,16.0,1.0,4.0]  |[1.2482145486117338,-1.2482145486117338]    |[0.7769906379765039,0.2230093620234961] |0.0       |
|[0.0,27.0,10.0,1.0,2.0,12.0,1.0,4.0] |[0.6770360847975654,-0.6770360847975654]    |[0.6630768615308953,0.33692313846910465]|0.0       |
|[0.0,32.0,10.0,1.0,4.0,17.0,5.0,4.0] |[1.343039637398138,-1.343039637398138]      |[0.7929893642953615,0.20701063570463848]|0.0       |
|[0.0,32.0,10.0,1.0,5.0,14.0,4.0,5.0] |[2.220023246987134,-2.220023246987134]      |[0.9020332500408325,0.09796674995916752]|0.0       |
|[0.0,32.0,15.0,1.0,3.0,18.0,5.0,4.0] |[0.5532756896916551,-0.5532756896916551]    |[0.6348952415965647,0.3651047584034352] |0.0       |
|[0.0,37.0,15.0,1.0,4.0,17.0,1.0,5.0] |[1.7581459850319243,-1.7581459850319243]    |[0.8529773058286395,0.14702269417136052]|0.0       |
|[0.0,52.0,15.0,1.0,5.0,9.0,5.0,5.0]  |[2.6088743974586124,-2.6088743974586124]    |[0.9314305415455806,0.06856945845441945]|0.0       |
|[0.0,52.0,15.0,1.0,5.0,12.0,1.0,3.0] |[1.8410975503955256,-1.8410975503955256]    |[0.8630784610725231,0.13692153892747697]|0.0       |
|[0.0,57.0,15.0,1.0,4.0,16.0,6.0,4.0] |[1.904911346081691,-1.904911346081691]      |[0.8704463839526814,0.1295536160473186] |0.0       |
|[1.0,22.0,4.0,0.0,1.0,18.0,5.0,5.0]  |[1.2616839124674724,-1.2616839124674724]    |[0.7793158477292919,0.22068415227070803]|0.0       |
|[1.0,22.0,4.0,0.0,2.0,18.0,5.0,5.0]  |[1.6042279173740237,-1.6042279173740237]    |[0.832608465695705,0.16739153430429493] |0.0       |
|[1.0,27.0,4.0,1.0,3.0,16.0,5.0,5.0]  |[1.4818864529709268,-1.4818864529709268]    |[0.8148573492085158,0.1851426507914842] |0.0       |
|[1.0,27.0,4.0,1.0,4.0,14.0,5.0,4.0]  |[1.379009097740017,-1.379009097740017]      |[0.7988318098541624,0.2011681901458377] |0.0       |
|[1.0,32.0,0.125,1.0,2.0,18.0,5.0,2.0]|[0.28148664352576547,-0.28148664352576547]  |[0.569910656659749,0.430089343340251]   |0.0       |
|[1.0,32.0,10.0,1.0,2.0,20.0,6.0,3.0] |[-0.1851761257948623,0.1851761257948623]    |[0.45383780246566996,0.5461621975343299]|1.0       |
|[1.0,32.0,10.0,1.0,4.0,20.0,6.0,4.0] |[0.9625930297088949,-0.9625930297088949]    |[0.7236406723848533,0.2763593276151468] |0.0       |
|[1.0,32.0,15.0,1.0,1.0,16.0,5.0,5.0] |[0.039440462424945366,-0.039440462424945366]|[0.5098588376463971,0.4901411623536029] |0.0       |
|[1.0,37.0,4.0,1.0,1.0,18.0,5.0,4.0]  |[0.7319377705508958,-0.7319377705508958]    |[0.6752303588678488,0.3247696411321513] |0.0       |
|[1.0,37.0,15.0,1.0,5.0,20.0,5.0,4.0] |[1.119955894572572,-1.119955894572572]      |[0.7539805352533917,0.24601946474660835]|0.0       |
|[1.0,42.0,15.0,1.0,4.0,17.0,6.0,5.0] |[1.4276540623429193,-1.4276540623429193]    |[0.8065355283195409,0.19346447168045908]|0.0       |
|[1.0,42.0,15.0,1.0,4.0,20.0,4.0,5.0] |[1.4935019453371354,-1.4935019453371354]    |[0.8166033137058254,0.1833966862941747] |0.0       |
|[1.0,42.0,15.0,1.0,4.0,20.0,6.0,3.0] |[0.4764020926318233,-0.4764020926318233]    |[0.6168979221749373,0.38310207782506256]|0.0       |
|[1.0,57.0,15.0,1.0,2.0,14.0,4.0,4.0] |[1.0201325344483316,-1.0201325344483316]    |[0.734998414766428,0.265001585233572]   |0.0       |
|[1.0,57.0,15.0,1.0,2.0,14.0,7.0,2.0] |[-0.04283609891898266,0.04283609891898266]  |[0.48929261249695394,0.5107073875030461]|1.0       |
|[1.0,57.0,15.0,1.0,5.0,20.0,5.0,3.0] |[1.4874352661557535,-1.4874352661557535]    |[0.8156930079647114,0.18430699203528864]|0.0       |
+-------------------------------------+--------------------------------------------+----------------------------------------+----------+
only showing top 30 rows
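The imports at the top also bring in BinaryClassificationEvaluator, which is never used above. A short sketch of how it could score the fitted model on the held-out testDF (an addition for illustration, not from the source post):

// Compute test-set AUC from the rawPrediction column produced by transform
val evaluator = new BinaryClassificationEvaluator()
  .setLabelCol("affairs")
  .setRawPredictionCol("rawPrediction")
  .setMetricName("areaUnderROC")
val testAUC = evaluator.evaluate(lrModel.transform(testDF))
println(s"test areaUnderROC: $testAUC")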

Reposted from: http://www.cnblogs.com/wwxbi/p/6224670.html
