LogisticRegression: Modeling with Logistic Regression

hblt-j · 2017/08/29 14:11

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.Row
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Column
import org.apache.spark.sql.DataFrameReader
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.DataFrameStatFunctions
import org.apache.spark.sql.functions._

import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.classification.{ BinaryLogisticRegressionSummary, LogisticRegression }
import org.apache.spark.ml.tuning.{ ParamGridBuilder, TrainValidationSplit }

val spark = SparkSession.builder().appName("Spark Logistic Regression").config("spark.some.config.option", "some-value").getOrCreate()

 

// For implicit conversions like converting RDDs to DataFrames

import spark.implicits._

 

val dataList: List[(Double, String, Double, Double, String, Double, Double, Double, Double)] = List(
      (0, "male", 37, 10, "no", 3, 18, 7, 4),
      (0, "female", 27, 4, "no", 4, 14, 6, 4)
      // ... the remaining rows are truncated in the original post
)

 

sqlDF.show()

+-------+------+----+------------+--------+-------------+---------+----------+------+

|affairs|gender| age|yearsmarried|children|religiousness|education|occupation|rating|

+-------+------+----+------------+--------+-------------+---------+----------+------+

|      0|     1|37.0|        10.0|       0|          3.0|     18.0|       7.0|   4.0|

|      0|     0|27.0|         4.0|       0|          4.0|     14.0|       6.0|   4.0|

|      0|     0|32.0|        15.0|       1|          1.0|     12.0|       1.0|   4.0|

|      0|     1|57.0|        15.0|       1|          5.0|     18.0|       6.0|   5.0|

|      0|     1|22.0|        0.75|       0|          2.0|     17.0|       6.0|   3.0|

|      0|     0|32.0|         1.5|       0|          2.0|     17.0|       5.0|   5.0|

|      0|     0|22.0|        0.75|       0|          2.0|     12.0|       1.0|   3.0|

|      0|     1|57.0|        15.0|       1|          2.0|     14.0|       4.0|   4.0|

|      0|     0|32.0|        15.0|       1|          4.0|     16.0|       1.0|   2.0|

|      0|     1|22.0|         1.5|       0|          4.0|     14.0|       4.0|   5.0|

|      0|     1|37.0|        15.0|       1|          2.0|     20.0|       7.0|   2.0|

|      0|     1|27.0|         4.0|       1|          4.0|     18.0|       6.0|   4.0|

|      0|     1|47.0|        15.0|       1|          5.0|     17.0|       6.0|   4.0|

|      0|     0|22.0|         1.5|       0|          2.0|     17.0|       5.0|   4.0|

|      0|     0|27.0|         4.0|       0|          4.0|     14.0|       5.0|   4.0|

|      0|     0|37.0|        15.0|       1|          1.0|     17.0|       5.0|   5.0|

|      0|     0|37.0|        15.0|       1|          2.0|     18.0|       4.0|   3.0|

|      0|     0|22.0|        0.75|       0|          3.0|     16.0|       5.0|   4.0|

|      0|     0|22.0|         1.5|       0|          2.0|     16.0|       5.0|   5.0|

|      0|     0|27.0|        10.0|       1|          2.0|     14.0|       1.0|   5.0|

+-------+------+----+------------+--------+-------------+---------+----------+------+

only showing top 20 rows

 

val colArray2 = Array("gender", "age", "yearsmarried", "children", "religiousness", "education", "occupation", "rating")

colArray2: Array[String] = Array(gender, age, yearsmarried, children, religiousness, education, occupation, rating)

 

val vecDF: DataFrame = new VectorAssembler().setInputCols(colArray2).setOutputCol("features").transform(sqlDF)

vecDF: org.apache.spark.sql.DataFrame = [affairs: int, gender: int ... 8 more fields]
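
A quick sanity check, not in the original post, to confirm the assembler packed the eight input columns into a single vector column:

vecDF.select("affairs", "features").show(3, false)  // e.g. features = [1.0,37.0,10.0,0.0,3.0,18.0,7.0,4.0]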

 

 

val Array(trainingDF, testDF) = vecDF.randomSplit(Array(0.9, 0.1), seed = 12345)

trainingDF: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [affairs: int, gender: int ... 8 more fields]

testDF: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [affairs: int, gender: int ... 8 more fields]

 

val lrModel = new LogisticRegression().setLabelCol("affairs").setFeaturesCol("features").fit(trainingDF)

lrModel: org.apache.spark.ml.classification.LogisticRegressionModel = logreg_9d8a91cb1a0b

 

 

// Print the coefficients and intercept of the fitted logistic regression

println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")

Coefficients: [0.308688148697453,-0.04150802586369178,0.08771801000466706,0.6896853841812993,-0.3425440049065515,0.008629892776596084,0.0458687806620022,-0.46268114569065383] Intercept: 1.263200227888706

  

 

// Set the ElasticNet mixing parameter, in the range [0, 1].

// For alpha = 0 the penalty is pure L2; for alpha = 1 it is pure L1; for 0 < alpha < 1 it is a mix of L1 and L2. The default is 0.0, i.e. an L2 penalty.

lrModel.getElasticNetParam

res5: Double = 0.0

 

lrModel.getRegParam  // regularization parameter (>= 0)

res6: Double = 0.0

 

lrModel.getStandardization  // whether to standardize the features before fitting the model

res7: Boolean = true

 

// Threshold for binary classification, in the range [0, 1]. If the estimated probability of class label 1 is > threshold, predict 1, otherwise 0. A high threshold encourages the model to predict 0 more often; a low threshold encourages predicting 1 more often. The default is 0.5.

lrModel.getThreshold

res8: Double = 0.5

 

// Convergence tolerance for the iterative optimizer. Smaller values give higher precision at the cost of more iterations. The default is 1E-6.

lrModel.getTol

res9: Double = 1.0E-6
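
The getters above only read back the defaults. A minimal sketch of setting the same hyperparameters before fitting (the values are illustrative assumptions, not from the original post):

val tunedLR = new LogisticRegression()
  .setLabelCol("affairs")
  .setFeaturesCol("features")
  .setRegParam(0.01)         // regularization strength (>= 0)
  .setElasticNetParam(0.5)   // 0 = L2 penalty, 1 = L1, in between = a mix
  .setStandardization(true)  // standardize features before fitting
  .setThreshold(0.5)         // decision threshold for predicting class 1
  .setTol(1e-6)              // convergence tolerance
  .setMaxIter(100)           // cap on optimizer iterations
val tunedModel = tunedLR.fit(trainingDF)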

 

 

lrModel.transform(testDF).show

+-------+------+----+------------+--------+-------------+---------+----------+------+--------------------+--------------------+--------------------+----------+

|affairs|gender| age|yearsmarried|children|religiousness|education|occupation|rating|            features|       rawPrediction|         probability|prediction|

+-------+------+----+------------+--------+-------------+---------+----------+------+--------------------+--------------------+--------------------+----------+

|      0|     0|22.0|       0.125|       0|          4.0|     14.0|       4.0|   5.0|[0.0,22.0,0.125,0...|[3.01829971642105...|[0.95339403355398...|       0.0|

|      0|     0|22.0|       0.417|       1|          3.0|     14.0|       3.0|   5.0|[0.0,22.0,0.417,1...|[2.00632544907384...|[0.88145961149358...|       0.0|

|      0|     0|27.0|         1.5|       0|          2.0|     16.0|       6.0|   5.0|[0.0,27.0,1.5,0.0...|[2.31114222529279...|[0.90979563879849...|       0.0|

|      0|     0|27.0|         4.0|       1|          3.0|     18.0|       4.0|   5.0|[0.0,27.0,4.0,1.0...|[1.81918359677719...|[0.86046813628746...|       0.0|

|      0|     0|27.0|         7.0|       1|          2.0|     18.0|       1.0|   5.0|[0.0,27.0,7.0,1.0...|[1.35109190384264...|[0.79430808378365...|       0.0|

|      0|     0|27.0|         7.0|       1|          3.0|     16.0|       1.0|   4.0|[0.0,27.0,7.0,1.0...|[1.24821454861173...|[0.77699063797650...|       0.0|

|      0|     0|27.0|        10.0|       1|          2.0|     12.0|       1.0|   4.0|[0.0,27.0,10.0,1....|[0.67703608479756...|[0.66307686153089...|       0.0|

|      0|     0|32.0|        10.0|       1|          4.0|     17.0|       5.0|   4.0|[0.0,32.0,10.0,1....|[1.34303963739813...|[0.79298936429536...|       0.0|

|      0|     0|32.0|        10.0|       1|          5.0|     14.0|       4.0|   5.0|[0.0,32.0,10.0,1....|[2.22002324698713...|[0.90203325004083...|       0.0|

|      0|     0|32.0|        15.0|       1|          3.0|     18.0|       5.0|   4.0|[0.0,32.0,15.0,1....|[0.55327568969165...|[0.63489524159656...|       0.0|

|      0|     0|37.0|        15.0|       1|          4.0|     17.0|       1.0|   5.0|[0.0,37.0,15.0,1....|[1.75814598503192...|[0.85297730582863...|       0.0|

|      0|     0|52.0|        15.0|       1|          5.0|      9.0|       5.0|   5.0|[0.0,52.0,15.0,1....|[2.60887439745861...|[0.93143054154558...|       0.0|

|      0|     0|52.0|        15.0|       1|          5.0|     12.0|       1.0|   3.0|[0.0,52.0,15.0,1....|[1.84109755039552...|[0.86307846107252...|       0.0|

|      0|     0|57.0|        15.0|       1|          4.0|     16.0|       6.0|   4.0|[0.0,57.0,15.0,1....|[1.90491134608169...|[0.87044638395268...|       0.0|

|      0|     1|22.0|         4.0|       0|          1.0|     18.0|       5.0|   5.0|[1.0,22.0,4.0,0.0...|[1.26168391246747...|[0.77931584772929...|       0.0|

|      0|     1|22.0|         4.0|       0|          2.0|     18.0|       5.0|   5.0|[1.0,22.0,4.0,0.0...|[1.60422791737402...|[0.83260846569570...|       0.0|

|      0|     1|27.0|         4.0|       1|          3.0|     16.0|       5.0|   5.0|[1.0,27.0,4.0,1.0...|[1.48188645297092...|[0.81485734920851...|       0.0|

|      0|     1|27.0|         4.0|       1|          4.0|     14.0|       5.0|   4.0|[1.0,27.0,4.0,1.0...|[1.37900909774001...|[0.79883180985416...|       0.0|

|      0|     1|32.0|       0.125|       1|          2.0|     18.0|       5.0|   2.0|[1.0,32.0,0.125,1...|[0.28148664352576...|[0.56991065665974...|       0.0|

|      0|     1|32.0|        10.0|       1|          2.0|     20.0|       6.0|   3.0|[1.0,32.0,10.0,1....|[-0.1851761257948...|[0.45383780246566...|       1.0|

+-------+------+----+------------+--------+-------------+---------+----------+------+--------------------+--------------------+--------------------+----------+

only showing top 20 rows

 

 

 

// Extract the summary from the returned LogisticRegressionModel instance trained in the earlier

// example

val trainingSummary = lrModel.summary

trainingSummary: org.apache.spark.ml.classification.LogisticRegressionTrainingSummary = org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary@4cde233d

 

 

// Obtain the objective per iteration.

val objectiveHistory = trainingSummary.objectiveHistory

objectiveHistory: Array[Double] = Array(0.5613118243072733, 0.5564125149222438, 0.5365395467216898, 0.5160918427628939, 0.51304621799159, 0.5105231964507352, 0.5079869547558363, 0.5072888873031864, 0.5067113660796532, 0.506520677080951, 0.5059147658563949, 0.5053652033316485, 0.5047266888422277, 0.5045473900598205, 0.5041496504941453, 0.5034630545828777, 0.5025745763542784, 0.5019910559468922, 0.5012033102192196, 0.5009489760675826, 0.5008431925740259, 0.5008297629370251, 0.5008258245513862, 0.5008137617093257, 0.5008136785235711, 0.5008130045533166, 0.5008129888367148, 0.5008129675120628, 0.5008129469652479, 0.5008129168191972, 0.5008129132692991, 0.5008129124596163, 0.5008129124081014, 0.500812912251931, 0.5008129121356268)

objectiveHistory.foreach(loss => println(loss))

0.5613118243072733

0.5564125149222438

0.5365395467216898

0.5160918427628939

0.51304621799159

0.5105231964507352

0.5079869547558363

0.5072888873031864

0.5067113660796532

0.506520677080951

0.5059147658563949

0.5053652033316485

0.5047266888422277

0.5045473900598205

0.5041496504941453

0.5034630545828777

0.5025745763542784

0.5019910559468922

0.5012033102192196

0.5009489760675826

0.5008431925740259

0.5008297629370251

0.5008258245513862

0.5008137617093257

0.5008136785235711

0.5008130045533166

0.5008129888367148

0.5008129675120628

0.5008129469652479

0.5008129168191972

0.5008129132692991

0.5008129124596163

0.5008129124081014

0.500812912251931

0.5008129121356268
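
Beyond the per-iteration loss, the training summary can be downcast to the BinaryLogisticRegressionSummary imported above to read binary metrics; a minimal sketch, not shown in the original post:

val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary]
println(s"Training AUC: ${binarySummary.areaUnderROC}")  // area under the ROC curve
binarySummary.roc.show(5)  // (FPR, TPR) points along the ROC curve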

 

 

lrModel.transform(testDF).select("features", "rawPrediction", "probability", "prediction").show(30, false)

+-------------------------------------+--------------------------------------------+----------------------------------------+----------+

|features                             |rawPrediction                               |probability                             |prediction|

+-------------------------------------+--------------------------------------------+----------------------------------------+----------+

|[0.0,22.0,0.125,0.0,4.0,14.0,4.0,5.0]|[3.0182997164210517,-3.0182997164210517]    |[0.9533940335539883,0.04660596644601167]|0.0       |

|[0.0,22.0,0.417,1.0,3.0,14.0,3.0,5.0]|[2.00632544907384,-2.00632544907384]        |[0.8814596114935873,0.11854038850641263]|0.0       |

|[0.0,27.0,1.5,0.0,2.0,16.0,6.0,5.0]  |[2.311142225292793,-2.311142225292793]      |[0.9097956387984996,0.09020436120150035]|0.0       |

|[0.0,27.0,4.0,1.0,3.0,18.0,4.0,5.0]  |[1.81918359677719,-1.81918359677719]        |[0.8604681362874618,0.13953186371253828]|0.0       |

|[0.0,27.0,7.0,1.0,2.0,18.0,1.0,5.0]  |[1.351091903842644,-1.351091903842644]      |[0.7943080837836515,0.20569191621634847]|0.0       |

|[0.0,27.0,7.0,1.0,3.0,16.0,1.0,4.0]  |[1.2482145486117338,-1.2482145486117338]    |[0.7769906379765039,0.2230093620234961] |0.0       |

|[0.0,27.0,10.0,1.0,2.0,12.0,1.0,4.0] |[0.6770360847975654,-0.6770360847975654]    |[0.6630768615308953,0.33692313846910465]|0.0       |

|[0.0,32.0,10.0,1.0,4.0,17.0,5.0,4.0] |[1.343039637398138,-1.343039637398138]      |[0.7929893642953615,0.20701063570463848]|0.0       |

|[0.0,32.0,10.0,1.0,5.0,14.0,4.0,5.0] |[2.220023246987134,-2.220023246987134]      |[0.9020332500408325,0.09796674995916752]|0.0       |

|[0.0,32.0,15.0,1.0,3.0,18.0,5.0,4.0] |[0.5532756896916551,-0.5532756896916551]    |[0.6348952415965647,0.3651047584034352] |0.0       |

|[0.0,37.0,15.0,1.0,4.0,17.0,1.0,5.0] |[1.7581459850319243,-1.7581459850319243]    |[0.8529773058286395,0.14702269417136052]|0.0       |

|[0.0,52.0,15.0,1.0,5.0,9.0,5.0,5.0]  |[2.6088743974586124,-2.6088743974586124]    |[0.9314305415455806,0.06856945845441945]|0.0       |

|[0.0,52.0,15.0,1.0,5.0,12.0,1.0,3.0] |[1.8410975503955256,-1.8410975503955256]    |[0.8630784610725231,0.13692153892747697]|0.0       |

|[0.0,57.0,15.0,1.0,4.0,16.0,6.0,4.0] |[1.904911346081691,-1.904911346081691]      |[0.8704463839526814,0.1295536160473186] |0.0       |

|[1.0,22.0,4.0,0.0,1.0,18.0,5.0,5.0]  |[1.2616839124674724,-1.2616839124674724]    |[0.7793158477292919,0.22068415227070803]|0.0       |

|[1.0,22.0,4.0,0.0,2.0,18.0,5.0,5.0]  |[1.6042279173740237,-1.6042279173740237]    |[0.832608465695705,0.16739153430429493] |0.0       |

|[1.0,27.0,4.0,1.0,3.0,16.0,5.0,5.0]  |[1.4818864529709268,-1.4818864529709268]    |[0.8148573492085158,0.1851426507914842] |0.0       |

|[1.0,27.0,4.0,1.0,4.0,14.0,5.0,4.0]  |[1.379009097740017,-1.379009097740017]      |[0.7988318098541624,0.2011681901458377] |0.0       |

|[1.0,32.0,0.125,1.0,2.0,18.0,5.0,2.0]|[0.28148664352576547,-0.28148664352576547]  |[0.569910656659749,0.430089343340251]   |0.0       |

|[1.0,32.0,10.0,1.0,2.0,20.0,6.0,3.0] |[-0.1851761257948623,0.1851761257948623]    |[0.45383780246566996,0.5461621975343299]|1.0       |

|[1.0,32.0,10.0,1.0,4.0,20.0,6.0,4.0] |[0.9625930297088949,-0.9625930297088949]    |[0.7236406723848533,0.2763593276151468] |0.0       |

|[1.0,32.0,15.0,1.0,1.0,16.0,5.0,5.0] |[0.039440462424945366,-0.039440462424945366]|[0.5098588376463971,0.4901411623536029] |0.0       |

|[1.0,37.0,4.0,1.0,1.0,18.0,5.0,4.0]  |[0.7319377705508958,-0.7319377705508958]    |[0.6752303588678488,0.3247696411321513] |0.0       |

|[1.0,37.0,15.0,1.0,5.0,20.0,5.0,4.0] |[1.119955894572572,-1.119955894572572]      |[0.7539805352533917,0.24601946474660835]|0.0       |

|[1.0,42.0,15.0,1.0,4.0,17.0,6.0,5.0] |[1.4276540623429193,-1.4276540623429193]    |[0.8065355283195409,0.19346447168045908]|0.0       |

|[1.0,42.0,15.0,1.0,4.0,20.0,4.0,5.0] |[1.4935019453371354,-1.4935019453371354]    |[0.8166033137058254,0.1833966862941747] |0.0       |

|[1.0,42.0,15.0,1.0,4.0,20.0,6.0,3.0] |[0.4764020926318233,-0.4764020926318233]    |[0.6168979221749373,0.38310207782506256]|0.0       |

|[1.0,57.0,15.0,1.0,2.0,14.0,4.0,4.0] |[1.0201325344483316,-1.0201325344483316]    |[0.734998414766428,0.265001585233572]   |0.0       |

|[1.0,57.0,15.0,1.0,2.0,14.0,7.0,2.0] |[-0.04283609891898266,0.04283609891898266]  |[0.48929261249695394,0.5107073875030461]|1.0       |

|[1.0,57.0,15.0,1.0,5.0,20.0,5.0,3.0] |[1.4874352661557535,-1.4874352661557535]    |[0.8156930079647114,0.18430699203528864]|0.0       |

+-------------------------------------+--------------------------------------------+----------------------------------------+----------+

only showing top 30 rows
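
The imports at the top bring in BinaryClassificationEvaluator, ParamGridBuilder and TrainValidationSplit, but the post never uses them. A minimal sketch of how they could close the loop, scoring the test predictions and grid-searching the regularization settings (the grid values are illustrative assumptions):

val evaluator = new BinaryClassificationEvaluator()
  .setLabelCol("affairs")
  .setRawPredictionCol("rawPrediction")
  .setMetricName("areaUnderROC")
println(s"Test AUC: ${evaluator.evaluate(lrModel.transform(testDF))}")

val lr = new LogisticRegression().setLabelCol("affairs").setFeaturesCol("features")
val paramGrid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.0, 0.01, 0.1))
  .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))
  .build()
val tvs = new TrainValidationSplit()
  .setEstimator(lr)
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(paramGrid)
  .setTrainRatio(0.8)  // 80% of trainingDF for training, 20% for validation
val bestModel = tvs.fit(trainingDF).bestModel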

Reposted from: http://www.cnblogs.com/wwxbi/p/6224670.html
