
Modeling with LogisticRegression (logistic regression)

hblt-j
Published 2017/08/29 14:11

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.Row
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Column
import org.apache.spark.sql.DataFrameReader
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.DataFrameStatFunctions
import org.apache.spark.sql.functions._

import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.classification.{ BinaryLogisticRegressionSummary, LogisticRegression }
import org.apache.spark.ml.tuning.{ ParamGridBuilder, TrainValidationSplit }

val spark = SparkSession.builder().appName("Spark Logistic Regression").config("spark.some.config.option", "some-value").getOrCreate()

 

// For implicit conversions like converting RDDs to DataFrames

import spark.implicits._

 

val dataList: List[(Double, String, Double, Double, String, Double, Double, Double, Double)] = List(
      (0, "male", 37, 10, "no", 3, 18, 7, 4),
      (0, "female", 27, 4, "no", 4, 14, 6, 4)
      // ... the remaining rows of the Affairs dataset are omitted in this excerpt ...
)

// Turn the list into a DataFrame and recode the two string columns to numeric.
// This step was elided in the original post; it is reconstructed here from the
// output below (gender: 1 = male / 0 = female, children: 0 = no / 1 = yes).
val data = dataList.toDF("affairs", "gender", "age", "yearsmarried", "children", "religiousness", "education", "occupation", "rating")
data.createOrReplaceTempView("data")

val sqlDF = spark.sql("""
  select cast(affairs as int) as affairs,
         case when gender = 'male' then 1 else 0 end as gender,
         age, yearsmarried,
         case when children = 'yes' then 1 else 0 end as children,
         religiousness, education, occupation, rating
  from data
""")

sqlDF.show()

+-------+------+----+------------+--------+-------------+---------+----------+------+

|affairs|gender| age|yearsmarried|children|religiousness|education|occupation|rating|

+-------+------+----+------------+--------+-------------+---------+----------+------+

|      0|     1|37.0|        10.0|       0|          3.0|     18.0|       7.0|   4.0|

|      0|     0|27.0|         4.0|       0|          4.0|     14.0|       6.0|   4.0|

|      0|     0|32.0|        15.0|       1|          1.0|     12.0|       1.0|   4.0|

|      0|     1|57.0|        15.0|       1|          5.0|     18.0|       6.0|   5.0|

|      0|     1|22.0|        0.75|       0|          2.0|     17.0|       6.0|   3.0|

|      0|     0|32.0|         1.5|       0|          2.0|     17.0|       5.0|   5.0|

|      0|     0|22.0|        0.75|       0|          2.0|     12.0|       1.0|   3.0|

|      0|     1|57.0|        15.0|       1|          2.0|     14.0|       4.0|   4.0|

|      0|     0|32.0|        15.0|       1|          4.0|     16.0|       1.0|   2.0|

|      0|     1|22.0|         1.5|       0|          4.0|     14.0|       4.0|   5.0|

|      0|     1|37.0|        15.0|       1|          2.0|     20.0|       7.0|   2.0|

|      0|     1|27.0|         4.0|       1|          4.0|     18.0|       6.0|   4.0|

|      0|     1|47.0|        15.0|       1|          5.0|     17.0|       6.0|   4.0|

|      0|     0|22.0|         1.5|       0|          2.0|     17.0|       5.0|   4.0|

|      0|     0|27.0|         4.0|       0|          4.0|     14.0|       5.0|   4.0|

|      0|     0|37.0|        15.0|       1|          1.0|     17.0|       5.0|   5.0|

|      0|     0|37.0|        15.0|       1|          2.0|     18.0|       4.0|   3.0|

|      0|     0|22.0|        0.75|       0|          3.0|     16.0|       5.0|   4.0|

|      0|     0|22.0|         1.5|       0|          2.0|     16.0|       5.0|   5.0|

|      0|     0|27.0|        10.0|       1|          2.0|     14.0|       1.0|   5.0|

+-------+------+----+------------+--------+-------------+---------+----------+------+

only showing top 20 rows

 

val colArray2 = Array("gender", "age", "yearsmarried", "children", "religiousness", "education", "occupation", "rating")

colArray2: Array[String] = Array(gender, age, yearsmarried, children, religiousness, education, occupation, rating)

 

val vecDF: DataFrame = new VectorAssembler().setInputCols(colArray2).setOutputCol("features").transform(sqlDF)

vecDF: org.apache.spark.sql.DataFrame = [affairs: int, gender: int ... 8 more fields]
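VectorAssembler does nothing more than concatenate the input columns, in the order given, into a single vector column. A toy pure-Scala sketch of that assembly for the first row of the table above (no Spark involved; the `row` map is just the first row's values restated):

```scala
// Values of the first row of sqlDF, keyed by column name
val row = Map(
  "gender" -> 1.0, "age" -> 37.0, "yearsmarried" -> 10.0, "children" -> 0.0,
  "religiousness" -> 3.0, "education" -> 18.0, "occupation" -> 7.0, "rating" -> 4.0)

// Same column order as colArray2 above
val cols = Array("gender", "age", "yearsmarried", "children",
  "religiousness", "education", "occupation", "rating")

// The "features" value is simply the column values in input order
val features = cols.map(row)
println(features.mkString("[", ",", "]")) // [1.0,37.0,10.0,0.0,3.0,18.0,7.0,4.0]
```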

 

 

val Array(trainingDF, testDF) = vecDF.randomSplit(Array(0.9, 0.1), seed = 12345)

trainingDF: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [affairs: int, gender: int ... 8 more fields]

testDF: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [affairs: int, gender: int ... 8 more fields]

 

val lrModel = new LogisticRegression().setLabelCol("affairs").setFeaturesCol("features").fit(trainingDF)

lrModel: org.apache.spark.ml.classification.LogisticRegressionModel = logreg_9d8a91cb1a0b

 

 

// Print the coefficients and intercept of the logistic regression model

println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")

Coefficients: [0.308688148697453,-0.04150802586369178,0.08771801000466706,0.6896853841812993,-0.3425440049065515,0.008629892776596084,0.0458687806620022,-0.46268114569065383] Intercept: 1.263200227888706
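These coefficients and the intercept fully determine the model's predictions. As a sanity check, a minimal pure-Scala sketch (no Spark needed; `x` is the feature vector of the first test row shown in the transform output further down) reproduces rawPrediction and probability by hand:

```scala
// Coefficients and intercept as printed by the fitted model
val coefficients = Array(0.308688148697453, -0.04150802586369178, 0.08771801000466706,
  0.6896853841812993, -0.3425440049065515, 0.008629892776596084,
  0.0458687806620022, -0.46268114569065383)
val intercept = 1.263200227888706

// Feature vector of the first test row:
// gender, age, yearsmarried, children, religiousness, education, occupation, rating
val x = Array(0.0, 22.0, 0.125, 0.0, 4.0, 14.0, 4.0, 5.0)

// margin = w . x + b; Spark reports rawPrediction as [-margin, margin]
val margin = coefficients.zip(x).map { case (w, xi) => w * xi }.sum + intercept

// The probability of class 1 is the logistic sigmoid of the margin
val pClass1 = 1.0 / (1.0 + math.exp(-margin))
val pClass0 = 1.0 - pClass1

println(f"rawPrediction=[${-margin}%.6f, $margin%.6f] probability=[$pClass0%.6f, $pClass1%.6f]")
```

The result matches the first row of the transform output: rawPrediction ≈ [3.018300, -3.018300] and probability ≈ [0.953394, 0.046606].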

  

 

// The ElasticNet mixing parameter, in the range [0, 1].

// For alpha = 0 the penalty is an L2 penalty; for alpha = 1 it is an L1 penalty; for 0 < alpha < 1 it is a mix of L1 and L2. The default is 0.0, i.e. an L2 penalty.

lrModel.getElasticNetParam

res5: Double = 0.0
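The penalty described in the comment above can be written out directly. A small illustrative sketch (plain Scala; the weights and lambda below are made-up values, not taken from the model):

```scala
// Elastic-net penalty: lambda * (alpha * ||w||_1 + (1 - alpha)/2 * ||w||_2^2)
// alpha = 0 gives a pure L2 penalty, alpha = 1 a pure L1 penalty
def elasticNetPenalty(w: Array[Double], lambda: Double, alpha: Double): Double = {
  val l1 = w.map(math.abs).sum
  val l2sq = w.map(wi => wi * wi).sum
  lambda * (alpha * l1 + (1.0 - alpha) / 2.0 * l2sq)
}

val w = Array(0.5, -0.25)
println(elasticNetPenalty(w, 0.1, 0.0)) // pure L2 penalty: 0.015625
println(elasticNetPenalty(w, 0.1, 1.0)) // pure L1 penalty: 0.075
```

Note that since getRegParam is 0.0 here (lambda = 0), the model above was effectively fit without any regularization, whatever the mixing parameter.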

 

lrModel.getRegParam  // regularization parameter, >= 0

res6: Double = 0.0

 

lrModel.getStandardization  // whether to standardize the features before fitting the model

res7: Boolean = true

 

// The threshold for binary classification, in the range [0, 1]. If the estimated probability of class 1 is greater than the threshold, predict 1, otherwise 0. A high threshold encourages the model to predict 0 more often; a low threshold encourages it to predict 1 more often. The default is 0.5.

lrModel.getThreshold

res8: Double = 0.5

 

// The convergence tolerance for the iterative optimizer. Smaller values give higher accuracy at the cost of more iterations. The default is 1E-6.

lrModel.getTol

res9: Double = 1.0E-6

 

 

lrModel.transform(testDF).show

+-------+------+----+------------+--------+-------------+---------+----------+------+--------------------+--------------------+--------------------+----------+

|affairs|gender| age|yearsmarried|children|religiousness|education|occupation|rating|            features|       rawPrediction|         probability|prediction|

+-------+------+----+------------+--------+-------------+---------+----------+------+--------------------+--------------------+--------------------+----------+

|      0|     0|22.0|       0.125|       0|          4.0|     14.0|       4.0|   5.0|[0.0,22.0,0.125,0...|[3.01829971642105...|[0.95339403355398...|       0.0|

|      0|     0|22.0|       0.417|       1|          3.0|     14.0|       3.0|   5.0|[0.0,22.0,0.417,1...|[2.00632544907384...|[0.88145961149358...|       0.0|

|      0|     0|27.0|         1.5|       0|          2.0|     16.0|       6.0|   5.0|[0.0,27.0,1.5,0.0...|[2.31114222529279...|[0.90979563879849...|       0.0|

|      0|     0|27.0|         4.0|       1|          3.0|     18.0|       4.0|   5.0|[0.0,27.0,4.0,1.0...|[1.81918359677719...|[0.86046813628746...|       0.0|

|      0|     0|27.0|         7.0|       1|          2.0|     18.0|       1.0|   5.0|[0.0,27.0,7.0,1.0...|[1.35109190384264...|[0.79430808378365...|       0.0|

|      0|     0|27.0|         7.0|       1|          3.0|     16.0|       1.0|   4.0|[0.0,27.0,7.0,1.0...|[1.24821454861173...|[0.77699063797650...|       0.0|

|      0|     0|27.0|        10.0|       1|          2.0|     12.0|       1.0|   4.0|[0.0,27.0,10.0,1....|[0.67703608479756...|[0.66307686153089...|       0.0|

|      0|     0|32.0|        10.0|       1|          4.0|     17.0|       5.0|   4.0|[0.0,32.0,10.0,1....|[1.34303963739813...|[0.79298936429536...|       0.0|

|      0|     0|32.0|        10.0|       1|          5.0|     14.0|       4.0|   5.0|[0.0,32.0,10.0,1....|[2.22002324698713...|[0.90203325004083...|       0.0|

|      0|     0|32.0|        15.0|       1|          3.0|     18.0|       5.0|   4.0|[0.0,32.0,15.0,1....|[0.55327568969165...|[0.63489524159656...|       0.0|

|      0|     0|37.0|        15.0|       1|          4.0|     17.0|       1.0|   5.0|[0.0,37.0,15.0,1....|[1.75814598503192...|[0.85297730582863...|       0.0|

|      0|     0|52.0|        15.0|       1|          5.0|      9.0|       5.0|   5.0|[0.0,52.0,15.0,1....|[2.60887439745861...|[0.93143054154558...|       0.0|

|      0|     0|52.0|        15.0|       1|          5.0|     12.0|       1.0|   3.0|[0.0,52.0,15.0,1....|[1.84109755039552...|[0.86307846107252...|       0.0|

|      0|     0|57.0|        15.0|       1|          4.0|     16.0|       6.0|   4.0|[0.0,57.0,15.0,1....|[1.90491134608169...|[0.87044638395268...|       0.0|

|      0|     1|22.0|         4.0|       0|          1.0|     18.0|       5.0|   5.0|[1.0,22.0,4.0,0.0...|[1.26168391246747...|[0.77931584772929...|       0.0|

|      0|     1|22.0|         4.0|       0|          2.0|     18.0|       5.0|   5.0|[1.0,22.0,4.0,0.0...|[1.60422791737402...|[0.83260846569570...|       0.0|

|      0|     1|27.0|         4.0|       1|          3.0|     16.0|       5.0|   5.0|[1.0,27.0,4.0,1.0...|[1.48188645297092...|[0.81485734920851...|       0.0|

|      0|     1|27.0|         4.0|       1|          4.0|     14.0|       5.0|   4.0|[1.0,27.0,4.0,1.0...|[1.37900909774001...|[0.79883180985416...|       0.0|

|      0|     1|32.0|       0.125|       1|          2.0|     18.0|       5.0|   2.0|[1.0,32.0,0.125,1...|[0.28148664352576...|[0.56991065665974...|       0.0|

|      0|     1|32.0|        10.0|       1|          2.0|     20.0|       6.0|   3.0|[1.0,32.0,10.0,1....|[-0.1851761257948...|[0.45383780246566...|       1.0|

+-------+------+----+------------+--------+-------------+---------+----------+------+--------------------+--------------------+--------------------+----------+

only showing top 20 rows

 

 

 

// Extract the summary from the returned LogisticRegressionModel instance trained in the earlier

// example

val trainingSummary = lrModel.summary

trainingSummary: org.apache.spark.ml.classification.LogisticRegressionTrainingSummary = org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary@4cde233d

 

 

// Obtain the objective per iteration.

val objectiveHistory = trainingSummary.objectiveHistory

objectiveHistory: Array[Double] = Array(0.5613118243072733, 0.5564125149222438, 0.5365395467216898, 0.5160918427628939, 0.51304621799159, 0.5105231964507352, 0.5079869547558363, 0.5072888873031864, 0.5067113660796532, 0.506520677080951, 0.5059147658563949, 0.5053652033316485, 0.5047266888422277, 0.5045473900598205, 0.5041496504941453, 0.5034630545828777, 0.5025745763542784, 0.5019910559468922, 0.5012033102192196, 0.5009489760675826, 0.5008431925740259, 0.5008297629370251, 0.5008258245513862, 0.5008137617093257, 0.5008136785235711, 0.5008130045533166, 0.5008129888367148, 0.5008129675120628, 0.5008129469652479, 0.5008129168191972, 0.5008129132692991, 0.5008129124596163, 0.5008129124081014, 0.500812912251931, 0.5008129121356268)

objectiveHistory.foreach(loss => println(loss))

0.5613118243072733

0.5564125149222438

0.5365395467216898

0.5160918427628939

0.51304621799159

0.5105231964507352

0.5079869547558363

0.5072888873031864

0.5067113660796532

0.506520677080951

0.5059147658563949

0.5053652033316485

0.5047266888422277

0.5045473900598205

0.5041496504941453

0.5034630545828777

0.5025745763542784

0.5019910559468922

0.5012033102192196

0.5009489760675826

0.5008431925740259

0.5008297629370251

0.5008258245513862

0.5008137617093257

0.5008136785235711

0.5008130045533166

0.5008129888367148

0.5008129675120628

0.5008129469652479

0.5008129168191972

0.5008129132692991

0.5008129124596163

0.5008129124081014

0.500812912251931

0.5008129121356268
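The flattening of the loss values above is what the tolerance parameter controls. A rough sketch of a tolerance-based stopping rule (plain Scala; the actual L-BFGS convergence test inside Spark/Breeze is more involved, so this is only illustrative): compare the relative change of the last two objective values against tol.

```scala
// Illustrative convergence check: relative improvement of the last step vs. tol
def converged(history: Seq[Double], tol: Double): Boolean = {
  val h = history.takeRight(2)
  math.abs(h(0) - h(1)) < tol * math.abs(h(0))
}

// Last two losses from the objective history above: essentially flat
println(converged(Seq(0.500812912251931, 0.5008129121356268), 1e-6)) // true
// The first step was still a large improvement
println(converged(Seq(0.5613118243072733, 0.5564125149222438), 1e-6)) // false
```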

 

 

 lrModel.transform(testDF).select("features","rawPrediction","probability","prediction").show(30,false)

+-------------------------------------+--------------------------------------------+----------------------------------------+----------+

|features                             |rawPrediction                               |probability                             |prediction|

+-------------------------------------+--------------------------------------------+----------------------------------------+----------+

|[0.0,22.0,0.125,0.0,4.0,14.0,4.0,5.0]|[3.0182997164210517,-3.0182997164210517]    |[0.9533940335539883,0.04660596644601167]|0.0       |

|[0.0,22.0,0.417,1.0,3.0,14.0,3.0,5.0]|[2.00632544907384,-2.00632544907384]        |[0.8814596114935873,0.11854038850641263]|0.0       |

|[0.0,27.0,1.5,0.0,2.0,16.0,6.0,5.0]  |[2.311142225292793,-2.311142225292793]      |[0.9097956387984996,0.09020436120150035]|0.0       |

|[0.0,27.0,4.0,1.0,3.0,18.0,4.0,5.0]  |[1.81918359677719,-1.81918359677719]        |[0.8604681362874618,0.13953186371253828]|0.0       |

|[0.0,27.0,7.0,1.0,2.0,18.0,1.0,5.0]  |[1.351091903842644,-1.351091903842644]      |[0.7943080837836515,0.20569191621634847]|0.0       |

|[0.0,27.0,7.0,1.0,3.0,16.0,1.0,4.0]  |[1.2482145486117338,-1.2482145486117338]    |[0.7769906379765039,0.2230093620234961] |0.0       |

|[0.0,27.0,10.0,1.0,2.0,12.0,1.0,4.0] |[0.6770360847975654,-0.6770360847975654]    |[0.6630768615308953,0.33692313846910465]|0.0       |

|[0.0,32.0,10.0,1.0,4.0,17.0,5.0,4.0] |[1.343039637398138,-1.343039637398138]      |[0.7929893642953615,0.20701063570463848]|0.0       |

|[0.0,32.0,10.0,1.0,5.0,14.0,4.0,5.0] |[2.220023246987134,-2.220023246987134]      |[0.9020332500408325,0.09796674995916752]|0.0       |

|[0.0,32.0,15.0,1.0,3.0,18.0,5.0,4.0] |[0.5532756896916551,-0.5532756896916551]    |[0.6348952415965647,0.3651047584034352] |0.0       |

|[0.0,37.0,15.0,1.0,4.0,17.0,1.0,5.0] |[1.7581459850319243,-1.7581459850319243]    |[0.8529773058286395,0.14702269417136052]|0.0       |

|[0.0,52.0,15.0,1.0,5.0,9.0,5.0,5.0]  |[2.6088743974586124,-2.6088743974586124]    |[0.9314305415455806,0.06856945845441945]|0.0       |

|[0.0,52.0,15.0,1.0,5.0,12.0,1.0,3.0] |[1.8410975503955256,-1.8410975503955256]    |[0.8630784610725231,0.13692153892747697]|0.0       |

|[0.0,57.0,15.0,1.0,4.0,16.0,6.0,4.0] |[1.904911346081691,-1.904911346081691]      |[0.8704463839526814,0.1295536160473186] |0.0       |

|[1.0,22.0,4.0,0.0,1.0,18.0,5.0,5.0]  |[1.2616839124674724,-1.2616839124674724]    |[0.7793158477292919,0.22068415227070803]|0.0       |

|[1.0,22.0,4.0,0.0,2.0,18.0,5.0,5.0]  |[1.6042279173740237,-1.6042279173740237]    |[0.832608465695705,0.16739153430429493] |0.0       |

|[1.0,27.0,4.0,1.0,3.0,16.0,5.0,5.0]  |[1.4818864529709268,-1.4818864529709268]    |[0.8148573492085158,0.1851426507914842] |0.0       |

|[1.0,27.0,4.0,1.0,4.0,14.0,5.0,4.0]  |[1.379009097740017,-1.379009097740017]      |[0.7988318098541624,0.2011681901458377] |0.0       |

|[1.0,32.0,0.125,1.0,2.0,18.0,5.0,2.0]|[0.28148664352576547,-0.28148664352576547]  |[0.569910656659749,0.430089343340251]   |0.0       |

|[1.0,32.0,10.0,1.0,2.0,20.0,6.0,3.0] |[-0.1851761257948623,0.1851761257948623]    |[0.45383780246566996,0.5461621975343299]|1.0       |

|[1.0,32.0,10.0,1.0,4.0,20.0,6.0,4.0] |[0.9625930297088949,-0.9625930297088949]    |[0.7236406723848533,0.2763593276151468] |0.0       |

|[1.0,32.0,15.0,1.0,1.0,16.0,5.0,5.0] |[0.039440462424945366,-0.039440462424945366]|[0.5098588376463971,0.4901411623536029] |0.0       |

|[1.0,37.0,4.0,1.0,1.0,18.0,5.0,4.0]  |[0.7319377705508958,-0.7319377705508958]    |[0.6752303588678488,0.3247696411321513] |0.0       |

|[1.0,37.0,15.0,1.0,5.0,20.0,5.0,4.0] |[1.119955894572572,-1.119955894572572]      |[0.7539805352533917,0.24601946474660835]|0.0       |

|[1.0,42.0,15.0,1.0,4.0,17.0,6.0,5.0] |[1.4276540623429193,-1.4276540623429193]    |[0.8065355283195409,0.19346447168045908]|0.0       |

|[1.0,42.0,15.0,1.0,4.0,20.0,4.0,5.0] |[1.4935019453371354,-1.4935019453371354]    |[0.8166033137058254,0.1833966862941747] |0.0       |

|[1.0,42.0,15.0,1.0,4.0,20.0,6.0,3.0] |[0.4764020926318233,-0.4764020926318233]    |[0.6168979221749373,0.38310207782506256]|0.0       |

|[1.0,57.0,15.0,1.0,2.0,14.0,4.0,4.0] |[1.0201325344483316,-1.0201325344483316]    |[0.734998414766428,0.265001585233572]   |0.0       |

|[1.0,57.0,15.0,1.0,2.0,14.0,7.0,2.0] |[-0.04283609891898266,0.04283609891898266]  |[0.48929261249695394,0.5107073875030461]|1.0       |

|[1.0,57.0,15.0,1.0,5.0,20.0,5.0,3.0] |[1.4874352661557535,-1.4874352661557535]    |[0.8156930079647114,0.18430699203528864]|0.0       |

+-------------------------------------+--------------------------------------------+----------------------------------------+----------+

only showing top 30 rows
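BinaryClassificationEvaluator is imported at the top but never used in this walkthrough; it would score output like the above by area under the ROC curve. As a plain-Scala illustration of what that metric measures (the scores and labels below are made up): AUC is the fraction of positive/negative pairs in which the positive example receives the higher score.

```scala
// AUC as the probability that a random positive example outranks
// a random negative one (ties count as half a win)
def auc(scored: Seq[(Double, Int)]): Double = {
  val pos = scored.collect { case (s, 1) => s }
  val neg = scored.collect { case (s, 0) => s }
  val wins = for (p <- pos; n <- neg)
    yield (if (p > n) 1.0 else if (p == n) 0.5 else 0.0)
  wins.sum / (pos.size * neg.size)
}

// Hypothetical (probability-of-class-1, label) pairs
println(auc(Seq((0.8, 1), (0.4, 1), (0.6, 0), (0.3, 0)))) // 0.75
```

In Spark itself this would look roughly like `new BinaryClassificationEvaluator().setLabelCol("affairs").setRawPredictionCol("rawPrediction").evaluate(lrModel.transform(testDF))`, whose default metric is areaUnderROC.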

Reprinted from: http://www.cnblogs.com/wwxbi/p/6224670.html
