
LogisticRegression: Building a Logistic Regression Model

hblt-j
Published 2017/08/29 14:11

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.Row
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Column
import org.apache.spark.sql.DataFrameReader
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.DataFrameStatFunctions
import org.apache.spark.sql.functions._

import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.classification.{ BinaryLogisticRegressionSummary, LogisticRegression }
import org.apache.spark.ml.tuning.{ ParamGridBuilder, TrainValidationSplit }

val spark = SparkSession.builder().appName("Spark Logistic Regression").config("spark.some.config.option", "some-value").getOrCreate()

 

// For implicit conversions like converting RDDs to DataFrames

import spark.implicits._

 

// The original post lists the full "affairs" dataset here; only the first two
// rows survive in this copy. The closing row is copied from the third row of
// the output below (gender/children decoded back to their labels) purely so
// that the snippet parses; the elided rows are not reconstructed.
val dataList: List[(Double, String, Double, Double, String, Double, Double, Double, Double)] = List(
      (0, "male", 37, 10, "no", 3, 18, 7, 4),
      (0, "female", 27, 4, "no", 4, 14, 6, 4),
      // ... remaining rows omitted ...
      (0, "female", 32, 15, "yes", 1, 12, 1, 4))
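The step that turns dataList into sqlDF is also missing from this copy of the post. A minimal sketch that would reproduce the schema printed below (gender and children recoded to 0/1, affairs kept as an integer label) might look like this; the temp-view name and the exact SQL are assumptions, not the author's code:

val data = dataList.toDF("affairs", "gender", "age", "yearsmarried", "children", "religiousness", "education", "occupation", "rating")
data.createOrReplaceTempView("data")
val sqlDF = spark.sql("""
  SELECT CAST(affairs AS INT) AS affairs,
         CASE WHEN gender = 'male' THEN 1 ELSE 0 END AS gender,
         age, yearsmarried,
         CASE WHEN children = 'yes' THEN 1 ELSE 0 END AS children,
         religiousness, education, occupation, rating
  FROM data
""")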

 

sqlDF.show()
+-------+------+----+------------+--------+-------------+---------+----------+------+
|affairs|gender| age|yearsmarried|children|religiousness|education|occupation|rating|
+-------+------+----+------------+--------+-------------+---------+----------+------+
|      0|     1|37.0|        10.0|       0|          3.0|     18.0|       7.0|   4.0|
|      0|     0|27.0|         4.0|       0|          4.0|     14.0|       6.0|   4.0|
|      0|     0|32.0|        15.0|       1|          1.0|     12.0|       1.0|   4.0|
|      0|     1|57.0|        15.0|       1|          5.0|     18.0|       6.0|   5.0|
|      0|     1|22.0|        0.75|       0|          2.0|     17.0|       6.0|   3.0|
|      0|     0|32.0|         1.5|       0|          2.0|     17.0|       5.0|   5.0|
|      0|     0|22.0|        0.75|       0|          2.0|     12.0|       1.0|   3.0|
|      0|     1|57.0|        15.0|       1|          2.0|     14.0|       4.0|   4.0|
|      0|     0|32.0|        15.0|       1|          4.0|     16.0|       1.0|   2.0|
|      0|     1|22.0|         1.5|       0|          4.0|     14.0|       4.0|   5.0|
|      0|     1|37.0|        15.0|       1|          2.0|     20.0|       7.0|   2.0|
|      0|     1|27.0|         4.0|       1|          4.0|     18.0|       6.0|   4.0|
|      0|     1|47.0|        15.0|       1|          5.0|     17.0|       6.0|   4.0|
|      0|     0|22.0|         1.5|       0|          2.0|     17.0|       5.0|   4.0|
|      0|     0|27.0|         4.0|       0|          4.0|     14.0|       5.0|   4.0|
|      0|     0|37.0|        15.0|       1|          1.0|     17.0|       5.0|   5.0|
|      0|     0|37.0|        15.0|       1|          2.0|     18.0|       4.0|   3.0|
|      0|     0|22.0|        0.75|       0|          3.0|     16.0|       5.0|   4.0|
|      0|     0|22.0|         1.5|       0|          2.0|     16.0|       5.0|   5.0|
|      0|     0|27.0|        10.0|       1|          2.0|     14.0|       1.0|   5.0|
+-------+------+----+------------+--------+-------------+---------+----------+------+
only showing top 20 rows

 

val colArray2 = Array("gender", "age", "yearsmarried", "children", "religiousness", "education", "occupation", "rating")
colArray2: Array[String] = Array(gender, age, yearsmarried, children, religiousness, education, occupation, rating)

val vecDF: DataFrame = new VectorAssembler().setInputCols(colArray2).setOutputCol("features").transform(sqlDF)
vecDF: org.apache.spark.sql.DataFrame = [affairs: int, gender: int ... 8 more fields]
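A quick peek at the assembled vector column confirms that the feature order matches colArray2 (an illustrative line, not in the original post):

vecDF.select("affairs", "features").show(3, false)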

 

 

val Array(trainingDF, testDF) = vecDF.randomSplit(Array(0.9, 0.1), seed = 12345)
trainingDF: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [affairs: int, gender: int ... 8 more fields]
testDF: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [affairs: int, gender: int ... 8 more fields]

val lrModel = new LogisticRegression().setLabelCol("affairs").setFeaturesCol("features").fit(trainingDF)
lrModel: org.apache.spark.ml.classification.LogisticRegressionModel = logreg_9d8a91cb1a0b

 

 

// Print the coefficients and intercept of the fitted logistic regression
println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")
Coefficients: [0.308688148697453,-0.04150802586369178,0.08771801000466706,0.6896853841812993,-0.3425440049065515,0.008629892776596084,0.0458687806620022,-0.46268114569065383] Intercept: 1.263200227888706
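Each coefficient lines up with one column of colArray2 (the VectorAssembler input order), so pairing them up makes the weights readable. An illustrative addition, not part of the original post:

colArray2.zip(lrModel.coefficients.toArray).foreach { case (name, w) =>
  println(f"$name%-14s $w%+.4f")
}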

  

 

// ElasticNet mixing parameter, in the range [0, 1].
// For alpha = 0 the penalty is an L2 penalty; for alpha = 1 it is an L1 penalty.
// For 0 < alpha < 1 it is a combination of L1 and L2. The default is 0.0, i.e. an L2 penalty.
lrModel.getElasticNetParam
res5: Double = 0.0

 

lrModel.getRegParam  // regularization parameter (>= 0)
res6: Double = 0.0

 

lrModel.getStandardization  // whether to standardize the features before fitting the model
res7: Boolean = true

 

// Threshold used in binary classification, in the range [0, 1]. If the estimated
// probability of class label 1 is greater than the threshold, predict 1, otherwise 0.
// A high threshold encourages the model to predict 0 more often; a low threshold
// encourages it to predict 1 more often. The default is 0.5.
lrModel.getThreshold
res8: Double = 0.5

 

// Convergence tolerance for the iterative optimization. Smaller values give higher
// accuracy at the cost of more iterations. The default is 1E-6.
lrModel.getTol
res9: Double = 1.0E-6
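These getters only report the defaults. To change the behavior, set the corresponding params on the estimator before fitting. A sketch with illustrative values, using the standard spark.ml setters:

val lrConfigured = new LogisticRegression()
  .setLabelCol("affairs")
  .setFeaturesCol("features")
  .setMaxIter(100)          // cap on optimizer iterations
  .setRegParam(0.01)        // regularization strength (>= 0)
  .setElasticNetParam(0.5)  // 0 = pure L2, 1 = pure L1
  .setThreshold(0.5)        // decision threshold on P(label = 1)
  .setTol(1e-6)             // convergence tolerance
val lrModel2 = lrConfigured.fit(trainingDF)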

 

 

lrModel.transform(testDF).show
+-------+------+----+------------+--------+-------------+---------+----------+------+--------------------+--------------------+--------------------+----------+
|affairs|gender| age|yearsmarried|children|religiousness|education|occupation|rating|            features|       rawPrediction|         probability|prediction|
+-------+------+----+------------+--------+-------------+---------+----------+------+--------------------+--------------------+--------------------+----------+
|      0|     0|22.0|       0.125|       0|          4.0|     14.0|       4.0|   5.0|[0.0,22.0,0.125,0...|[3.01829971642105...|[0.95339403355398...|       0.0|
|      0|     0|22.0|       0.417|       1|          3.0|     14.0|       3.0|   5.0|[0.0,22.0,0.417,1...|[2.00632544907384...|[0.88145961149358...|       0.0|
|      0|     0|27.0|         1.5|       0|          2.0|     16.0|       6.0|   5.0|[0.0,27.0,1.5,0.0...|[2.31114222529279...|[0.90979563879849...|       0.0|
|      0|     0|27.0|         4.0|       1|          3.0|     18.0|       4.0|   5.0|[0.0,27.0,4.0,1.0...|[1.81918359677719...|[0.86046813628746...|       0.0|
|      0|     0|27.0|         7.0|       1|          2.0|     18.0|       1.0|   5.0|[0.0,27.0,7.0,1.0...|[1.35109190384264...|[0.79430808378365...|       0.0|
|      0|     0|27.0|         7.0|       1|          3.0|     16.0|       1.0|   4.0|[0.0,27.0,7.0,1.0...|[1.24821454861173...|[0.77699063797650...|       0.0|
|      0|     0|27.0|        10.0|       1|          2.0|     12.0|       1.0|   4.0|[0.0,27.0,10.0,1....|[0.67703608479756...|[0.66307686153089...|       0.0|
|      0|     0|32.0|        10.0|       1|          4.0|     17.0|       5.0|   4.0|[0.0,32.0,10.0,1....|[1.34303963739813...|[0.79298936429536...|       0.0|
|      0|     0|32.0|        10.0|       1|          5.0|     14.0|       4.0|   5.0|[0.0,32.0,10.0,1....|[2.22002324698713...|[0.90203325004083...|       0.0|
|      0|     0|32.0|        15.0|       1|          3.0|     18.0|       5.0|   4.0|[0.0,32.0,15.0,1....|[0.55327568969165...|[0.63489524159656...|       0.0|
|      0|     0|37.0|        15.0|       1|          4.0|     17.0|       1.0|   5.0|[0.0,37.0,15.0,1....|[1.75814598503192...|[0.85297730582863...|       0.0|
|      0|     0|52.0|        15.0|       1|          5.0|      9.0|       5.0|   5.0|[0.0,52.0,15.0,1....|[2.60887439745861...|[0.93143054154558...|       0.0|
|      0|     0|52.0|        15.0|       1|          5.0|     12.0|       1.0|   3.0|[0.0,52.0,15.0,1....|[1.84109755039552...|[0.86307846107252...|       0.0|
|      0|     0|57.0|        15.0|       1|          4.0|     16.0|       6.0|   4.0|[0.0,57.0,15.0,1....|[1.90491134608169...|[0.87044638395268...|       0.0|
|      0|     1|22.0|         4.0|       0|          1.0|     18.0|       5.0|   5.0|[1.0,22.0,4.0,0.0...|[1.26168391246747...|[0.77931584772929...|       0.0|
|      0|     1|22.0|         4.0|       0|          2.0|     18.0|       5.0|   5.0|[1.0,22.0,4.0,0.0...|[1.60422791737402...|[0.83260846569570...|       0.0|
|      0|     1|27.0|         4.0|       1|          3.0|     16.0|       5.0|   5.0|[1.0,27.0,4.0,1.0...|[1.48188645297092...|[0.81485734920851...|       0.0|
|      0|     1|27.0|         4.0|       1|          4.0|     14.0|       5.0|   4.0|[1.0,27.0,4.0,1.0...|[1.37900909774001...|[0.79883180985416...|       0.0|
|      0|     1|32.0|       0.125|       1|          2.0|     18.0|       5.0|   2.0|[1.0,32.0,0.125,1...|[0.28148664352576...|[0.56991065665974...|       0.0|
|      0|     1|32.0|        10.0|       1|          2.0|     20.0|       6.0|   3.0|[1.0,32.0,10.0,1....|[-0.1851761257948...|[0.45383780246566...|       1.0|
+-------+------+----+------------+--------+-------------+---------+----------+------+--------------------+--------------------+--------------------+----------+
only showing top 20 rows

 

 

 

// Extract the summary from the LogisticRegressionModel instance trained earlier
val trainingSummary = lrModel.summary
trainingSummary: org.apache.spark.ml.classification.LogisticRegressionTrainingSummary = org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary@4cde233d

 

 

// Obtain the objective per iteration.
val objectiveHistory = trainingSummary.objectiveHistory
objectiveHistory: Array[Double] = Array(0.5613118243072733, 0.5564125149222438, 0.5365395467216898, 0.5160918427628939, 0.51304621799159, 0.5105231964507352, 0.5079869547558363, 0.5072888873031864, 0.5067113660796532, 0.506520677080951, 0.5059147658563949, 0.5053652033316485, 0.5047266888422277, 0.5045473900598205, 0.5041496504941453, 0.5034630545828777, 0.5025745763542784, 0.5019910559468922, 0.5012033102192196, 0.5009489760675826, 0.5008431925740259, 0.5008297629370251, 0.5008258245513862, 0.5008137617093257, 0.5008136785235711, 0.5008130045533166, 0.5008129888367148, 0.5008129675120628, 0.5008129469652479, 0.5008129168191972, 0.5008129132692991, 0.5008129124596163, 0.5008129124081014, 0.500812912251931, 0.5008129121356268)

objectiveHistory.foreach(loss => println(loss))
0.5613118243072733
0.5564125149222438
0.5365395467216898
0.5160918427628939
0.51304621799159
0.5105231964507352
0.5079869547558363
0.5072888873031864
0.5067113660796532
0.506520677080951
0.5059147658563949
0.5053652033316485
0.5047266888422277
0.5045473900598205
0.5041496504941453
0.5034630545828777
0.5025745763542784
0.5019910559468922
0.5012033102192196
0.5009489760675826
0.5008431925740259
0.5008297629370251
0.5008258245513862
0.5008137617093257
0.5008136785235711
0.5008130045533166
0.5008129888367148
0.5008129675120628
0.5008129469652479
0.5008129168191972
0.5008129132692991
0.5008129124596163
0.5008129124081014
0.500812912251931
0.5008129121356268
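The BinaryLogisticRegressionSummary import at the top is never exercised in the post. The usual next step is to downcast the training summary and read ROC/AUC on the training data; a standard spark.ml (2.x) pattern, added here as a sketch:

val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary]
println(s"Training areaUnderROC: ${binarySummary.areaUnderROC}")
binarySummary.roc.show(5)  // DataFrame of (FPR, TPR) points along the ROC curve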

 

 

lrModel.transform(testDF).select("features","rawPrediction","probability","prediction").show(30,false)
+-------------------------------------+--------------------------------------------+----------------------------------------+----------+
|features                             |rawPrediction                               |probability                             |prediction|
+-------------------------------------+--------------------------------------------+----------------------------------------+----------+
|[0.0,22.0,0.125,0.0,4.0,14.0,4.0,5.0]|[3.0182997164210517,-3.0182997164210517]    |[0.9533940335539883,0.04660596644601167]|0.0       |
|[0.0,22.0,0.417,1.0,3.0,14.0,3.0,5.0]|[2.00632544907384,-2.00632544907384]        |[0.8814596114935873,0.11854038850641263]|0.0       |
|[0.0,27.0,1.5,0.0,2.0,16.0,6.0,5.0]  |[2.311142225292793,-2.311142225292793]      |[0.9097956387984996,0.09020436120150035]|0.0       |
|[0.0,27.0,4.0,1.0,3.0,18.0,4.0,5.0]  |[1.81918359677719,-1.81918359677719]        |[0.8604681362874618,0.13953186371253828]|0.0       |
|[0.0,27.0,7.0,1.0,2.0,18.0,1.0,5.0]  |[1.351091903842644,-1.351091903842644]      |[0.7943080837836515,0.20569191621634847]|0.0       |
|[0.0,27.0,7.0,1.0,3.0,16.0,1.0,4.0]  |[1.2482145486117338,-1.2482145486117338]    |[0.7769906379765039,0.2230093620234961] |0.0       |
|[0.0,27.0,10.0,1.0,2.0,12.0,1.0,4.0] |[0.6770360847975654,-0.6770360847975654]    |[0.6630768615308953,0.33692313846910465]|0.0       |
|[0.0,32.0,10.0,1.0,4.0,17.0,5.0,4.0] |[1.343039637398138,-1.343039637398138]      |[0.7929893642953615,0.20701063570463848]|0.0       |
|[0.0,32.0,10.0,1.0,5.0,14.0,4.0,5.0] |[2.220023246987134,-2.220023246987134]      |[0.9020332500408325,0.09796674995916752]|0.0       |
|[0.0,32.0,15.0,1.0,3.0,18.0,5.0,4.0] |[0.5532756896916551,-0.5532756896916551]    |[0.6348952415965647,0.3651047584034352] |0.0       |
|[0.0,37.0,15.0,1.0,4.0,17.0,1.0,5.0] |[1.7581459850319243,-1.7581459850319243]    |[0.8529773058286395,0.14702269417136052]|0.0       |
|[0.0,52.0,15.0,1.0,5.0,9.0,5.0,5.0]  |[2.6088743974586124,-2.6088743974586124]    |[0.9314305415455806,0.06856945845441945]|0.0       |
|[0.0,52.0,15.0,1.0,5.0,12.0,1.0,3.0] |[1.8410975503955256,-1.8410975503955256]    |[0.8630784610725231,0.13692153892747697]|0.0       |
|[0.0,57.0,15.0,1.0,4.0,16.0,6.0,4.0] |[1.904911346081691,-1.904911346081691]      |[0.8704463839526814,0.1295536160473186] |0.0       |
|[1.0,22.0,4.0,0.0,1.0,18.0,5.0,5.0]  |[1.2616839124674724,-1.2616839124674724]    |[0.7793158477292919,0.22068415227070803]|0.0       |
|[1.0,22.0,4.0,0.0,2.0,18.0,5.0,5.0]  |[1.6042279173740237,-1.6042279173740237]    |[0.832608465695705,0.16739153430429493] |0.0       |
|[1.0,27.0,4.0,1.0,3.0,16.0,5.0,5.0]  |[1.4818864529709268,-1.4818864529709268]    |[0.8148573492085158,0.1851426507914842] |0.0       |
|[1.0,27.0,4.0,1.0,4.0,14.0,5.0,4.0]  |[1.379009097740017,-1.379009097740017]      |[0.7988318098541624,0.2011681901458377] |0.0       |
|[1.0,32.0,0.125,1.0,2.0,18.0,5.0,2.0]|[0.28148664352576547,-0.28148664352576547]  |[0.569910656659749,0.430089343340251]   |0.0       |
|[1.0,32.0,10.0,1.0,2.0,20.0,6.0,3.0] |[-0.1851761257948623,0.1851761257948623]    |[0.45383780246566996,0.5461621975343299]|1.0       |
|[1.0,32.0,10.0,1.0,4.0,20.0,6.0,4.0] |[0.9625930297088949,-0.9625930297088949]    |[0.7236406723848533,0.2763593276151468] |0.0       |
|[1.0,32.0,15.0,1.0,1.0,16.0,5.0,5.0] |[0.039440462424945366,-0.039440462424945366]|[0.5098588376463971,0.4901411623536029] |0.0       |
|[1.0,37.0,4.0,1.0,1.0,18.0,5.0,4.0]  |[0.7319377705508958,-0.7319377705508958]    |[0.6752303588678488,0.3247696411321513] |0.0       |
|[1.0,37.0,15.0,1.0,5.0,20.0,5.0,4.0] |[1.119955894572572,-1.119955894572572]      |[0.7539805352533917,0.24601946474660835]|0.0       |
|[1.0,42.0,15.0,1.0,4.0,17.0,6.0,5.0] |[1.4276540623429193,-1.4276540623429193]    |[0.8065355283195409,0.19346447168045908]|0.0       |
|[1.0,42.0,15.0,1.0,4.0,20.0,4.0,5.0] |[1.4935019453371354,-1.4935019453371354]    |[0.8166033137058254,0.1833966862941747] |0.0       |
|[1.0,42.0,15.0,1.0,4.0,20.0,6.0,3.0] |[0.4764020926318233,-0.4764020926318233]    |[0.6168979221749373,0.38310207782506256]|0.0       |
|[1.0,57.0,15.0,1.0,2.0,14.0,4.0,4.0] |[1.0201325344483316,-1.0201325344483316]    |[0.734998414766428,0.265001585233572]   |0.0       |
|[1.0,57.0,15.0,1.0,2.0,14.0,7.0,2.0] |[-0.04283609891898266,0.04283609891898266]  |[0.48929261249695394,0.5107073875030461]|1.0       |
|[1.0,57.0,15.0,1.0,5.0,20.0,5.0,3.0] |[1.4874352661557535,-1.4874352661557535]    |[0.8156930079647114,0.18430699203528864]|0.0       |
+-------------------------------------+--------------------------------------------+----------------------------------------+----------+
only showing top 30 rows
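Neither BinaryClassificationEvaluator nor the tuning classes imported at the top are used in the post. As a sketch of the natural follow-up, the test predictions can be scored and the regularization tuned; the column wiring and metric name are the standard spark.ml ones, and the grid values are illustrative:

// Test-set AUC for the model trained above
val predictions = lrModel.transform(testDF)
val evaluator = new BinaryClassificationEvaluator()
  .setLabelCol("affairs")
  .setRawPredictionCol("rawPrediction")
  .setMetricName("areaUnderROC")
println(s"Test areaUnderROC: ${evaluator.evaluate(predictions)}")

// Hyperparameter search over regParam/elasticNetParam with TrainValidationSplit
val lr = new LogisticRegression().setLabelCol("affairs").setFeaturesCol("features")
val paramGrid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.01, 0.1))
  .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))
  .build()
val tvs = new TrainValidationSplit()
  .setEstimator(lr)
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(paramGrid)
  .setTrainRatio(0.8)
val bestModel = tvs.fit(trainingDF).bestModel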

Reposted from: http://www.cnblogs.com/wwxbi/p/6224670.html
