LogisticRegression: Building a Logistic Regression Model

This post walks through fitting a binary logistic regression model with Spark ML on the affairs dataset: assembling the feature vector, splitting the data, training, inspecting the hyperparameters, and examining predictions and the training summary.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.Row
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Column
import org.apache.spark.sql.DataFrameReader
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.DataFrameStatFunctions
import org.apache.spark.sql.functions._

import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.classification.{ BinaryLogisticRegressionSummary, LogisticRegression }
import org.apache.spark.ml.tuning.{ ParamGridBuilder, TrainValidationSplit }

val spark = SparkSession.builder()
  .appName("Spark Logistic Regression")
  .config("spark.some.config.option", "some-value")
  .getOrCreate()

 

// For implicit conversions like converting RDDs to DataFrames

import spark.implicits._

 

val dataList: List[(Double, String, Double, Double, String, Double, Double, Double, Double)] = List(
  (0, "male", 37, 10, "no", 3, 18, 7, 4),
  (0, "female", 27, 4, "no", 4, 14, 6, 4)
  // ... the remaining rows of the affairs dataset are truncated in the original post
)

 

sqlDF.show()
+-------+------+----+------------+--------+-------------+---------+----------+------+
|affairs|gender| age|yearsmarried|children|religiousness|education|occupation|rating|
+-------+------+----+------------+--------+-------------+---------+----------+------+
|      0|     1|37.0|        10.0|       0|          3.0|     18.0|       7.0|   4.0|
|      0|     0|27.0|         4.0|       0|          4.0|     14.0|       6.0|   4.0|
|      0|     0|32.0|        15.0|       1|          1.0|     12.0|       1.0|   4.0|
|      0|     1|57.0|        15.0|       1|          5.0|     18.0|       6.0|   5.0|
|      0|     1|22.0|        0.75|       0|          2.0|     17.0|       6.0|   3.0|
|      0|     0|32.0|         1.5|       0|          2.0|     17.0|       5.0|   5.0|
|      0|     0|22.0|        0.75|       0|          2.0|     12.0|       1.0|   3.0|
|      0|     1|57.0|        15.0|       1|          2.0|     14.0|       4.0|   4.0|
|      0|     0|32.0|        15.0|       1|          4.0|     16.0|       1.0|   2.0|
|      0|     1|22.0|         1.5|       0|          4.0|     14.0|       4.0|   5.0|
|      0|     1|37.0|        15.0|       1|          2.0|     20.0|       7.0|   2.0|
|      0|     1|27.0|         4.0|       1|          4.0|     18.0|       6.0|   4.0|
|      0|     1|47.0|        15.0|       1|          5.0|     17.0|       6.0|   4.0|
|      0|     0|22.0|         1.5|       0|          2.0|     17.0|       5.0|   4.0|
|      0|     0|27.0|         4.0|       0|          4.0|     14.0|       5.0|   4.0|
|      0|     0|37.0|        15.0|       1|          1.0|     17.0|       5.0|   5.0|
|      0|     0|37.0|        15.0|       1|          2.0|     18.0|       4.0|   3.0|
|      0|     0|22.0|        0.75|       0|          3.0|     16.0|       5.0|   4.0|
|      0|     0|22.0|         1.5|       0|          2.0|     16.0|       5.0|   5.0|
|      0|     0|27.0|        10.0|       1|          2.0|     14.0|       1.0|   5.0|
+-------+------+----+------------+--------+-------------+---------+----------+------+
only showing top 20 rows

 

val colArray2 = Array("gender", "age", "yearsmarried", "children", "religiousness", "education", "occupation", "rating")
colArray2: Array[String] = Array(gender, age, yearsmarried, children, religiousness, education, occupation, rating)

// Assemble the eight feature columns into a single vector column named "features".
val vecDF: DataFrame = new VectorAssembler().setInputCols(colArray2).setOutputCol("features").transform(sqlDF)
vecDF: org.apache.spark.sql.DataFrame = [affairs: int, gender: int ... 8 more fields]
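To see what VectorAssembler produced, the features column can be inspected directly (a sketch; its output is not in the original post and is omitted here):

vecDF.select("affairs", "features").show(5, false)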

 

 

// Split the data 90/10 into training and test sets, with a fixed seed for reproducibility.
val Array(trainingDF, testDF) = vecDF.randomSplit(Array(0.9, 0.1), seed = 12345)
trainingDF: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [affairs: int, gender: int ... 8 more fields]
testDF: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [affairs: int, gender: int ... 8 more fields]

// Fit a logistic regression model on the training set.
val lrModel = new LogisticRegression().setLabelCol("affairs").setFeaturesCol("features").fit(trainingDF)
lrModel: org.apache.spark.ml.classification.LogisticRegressionModel = logreg_9d8a91cb1a0b
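Pipeline is imported above but never used. The same flow could be expressed as a two-stage pipeline fitted on the raw (pre-assembly) columns, a sketch rather than what the post actually runs:

val Array(rawTrain, rawTest) = sqlDF.randomSplit(Array(0.9, 0.1), seed = 12345)
val pipeline = new Pipeline().setStages(Array(
  new VectorAssembler().setInputCols(colArray2).setOutputCol("features"),
  new LogisticRegression().setLabelCol("affairs").setFeaturesCol("features")
))
// The fitted PipelineModel applies assembly and prediction in one transform call.
val pipelineModel = pipeline.fit(rawTrain)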

 

 

// Print the coefficients and intercept of the logistic regression.
println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")
Coefficients: [0.308688148697453,-0.04150802586369178,0.08771801000466706,0.6896853841812993,-0.3425440049065515,0.008629892776596084,0.0458687806620022,-0.46268114569065383] Intercept: 1.263200227888706
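As a sanity check (not in the original post), the model's rawPrediction is just the linear margin w·x + intercept, and the probability is its sigmoid. Recomputing both for the first test row shown further below, [0.0,22.0,0.125,0.0,4.0,14.0,4.0,5.0]:

val x = Array(0.0, 22.0, 0.125, 0.0, 4.0, 14.0, 4.0, 5.0)  // features of the first test row
val margin = lrModel.coefficients.toArray.zip(x).map { case (w, xi) => w * xi }.sum + lrModel.intercept
// margin ≈ -3.0183, which equals rawPrediction(1) for that row
val p1 = 1.0 / (1.0 + math.exp(-margin))
// p1 ≈ 0.0466, matching probability(1) in the transform output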

  

 

// The ElasticNet mixing parameter, in the range [0, 1].
// alpha = 0 gives an L2 penalty; alpha = 1 gives an L1 penalty; 0 < alpha < 1 mixes L1 and L2.
// The default is 0.0, i.e. a pure L2 penalty.
lrModel.getElasticNetParam
res5: Double = 0.0

lrModel.getRegParam  // regularization parameter, >= 0
res6: Double = 0.0

lrModel.getStandardization  // whether to standardize the features before fitting the model
res7: Boolean = true

// The threshold for binary classification, in the range [0, 1]. If the estimated probability of
// class 1 is greater than the threshold, predict 1, otherwise 0. A high threshold encourages the
// model to predict 0 more often; a low threshold encourages it to predict 1 more often.
// The default is 0.5.
lrModel.getThreshold
res8: Double = 0.5

// The convergence tolerance of the iterative solver. Smaller values give higher accuracy at the
// cost of more iterations. The default is 1E-6.
lrModel.getTol
res9: Double = 1.0E-6
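The getters above all return defaults because the post never sets these parameters. A sketch of how they could be set explicitly before fitting (the specific values here are illustrative, not from the original post):

val tunedModel = new LogisticRegression()
  .setLabelCol("affairs")
  .setFeaturesCol("features")
  .setElasticNetParam(0.8)  // illustrative: mostly-L1 penalty
  .setRegParam(0.01)        // illustrative: light regularization
  .setThreshold(0.5)
  .setTol(1e-6)
  .setMaxIter(100)
  .fit(trainingDF)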

 

 

lrModel.transform(testDF).show
+-------+------+----+------------+--------+-------------+---------+----------+------+--------------------+--------------------+--------------------+----------+
|affairs|gender| age|yearsmarried|children|religiousness|education|occupation|rating|            features|       rawPrediction|         probability|prediction|
+-------+------+----+------------+--------+-------------+---------+----------+------+--------------------+--------------------+--------------------+----------+
|      0|     0|22.0|       0.125|       0|          4.0|     14.0|       4.0|   5.0|[0.0,22.0,0.125,0...|[3.01829971642105...|[0.95339403355398...|       0.0|
|      0|     0|22.0|       0.417|       1|          3.0|     14.0|       3.0|   5.0|[0.0,22.0,0.417,1...|[2.00632544907384...|[0.88145961149358...|       0.0|
|      0|     0|27.0|         1.5|       0|          2.0|     16.0|       6.0|   5.0|[0.0,27.0,1.5,0.0...|[2.31114222529279...|[0.90979563879849...|       0.0|
|      0|     0|27.0|         4.0|       1|          3.0|     18.0|       4.0|   5.0|[0.0,27.0,4.0,1.0...|[1.81918359677719...|[0.86046813628746...|       0.0|
|      0|     0|27.0|         7.0|       1|          2.0|     18.0|       1.0|   5.0|[0.0,27.0,7.0,1.0...|[1.35109190384264...|[0.79430808378365...|       0.0|
|      0|     0|27.0|         7.0|       1|          3.0|     16.0|       1.0|   4.0|[0.0,27.0,7.0,1.0...|[1.24821454861173...|[0.77699063797650...|       0.0|
|      0|     0|27.0|        10.0|       1|          2.0|     12.0|       1.0|   4.0|[0.0,27.0,10.0,1....|[0.67703608479756...|[0.66307686153089...|       0.0|
|      0|     0|32.0|        10.0|       1|          4.0|     17.0|       5.0|   4.0|[0.0,32.0,10.0,1....|[1.34303963739813...|[0.79298936429536...|       0.0|
|      0|     0|32.0|        10.0|       1|          5.0|     14.0|       4.0|   5.0|[0.0,32.0,10.0,1....|[2.22002324698713...|[0.90203325004083...|       0.0|
|      0|     0|32.0|        15.0|       1|          3.0|     18.0|       5.0|   4.0|[0.0,32.0,15.0,1....|[0.55327568969165...|[0.63489524159656...|       0.0|
|      0|     0|37.0|        15.0|       1|          4.0|     17.0|       1.0|   5.0|[0.0,37.0,15.0,1....|[1.75814598503192...|[0.85297730582863...|       0.0|
|      0|     0|52.0|        15.0|       1|          5.0|      9.0|       5.0|   5.0|[0.0,52.0,15.0,1....|[2.60887439745861...|[0.93143054154558...|       0.0|
|      0|     0|52.0|        15.0|       1|          5.0|     12.0|       1.0|   3.0|[0.0,52.0,15.0,1....|[1.84109755039552...|[0.86307846107252...|       0.0|
|      0|     0|57.0|        15.0|       1|          4.0|     16.0|       6.0|   4.0|[0.0,57.0,15.0,1....|[1.90491134608169...|[0.87044638395268...|       0.0|
|      0|     1|22.0|         4.0|       0|          1.0|     18.0|       5.0|   5.0|[1.0,22.0,4.0,0.0...|[1.26168391246747...|[0.77931584772929...|       0.0|
|      0|     1|22.0|         4.0|       0|          2.0|     18.0|       5.0|   5.0|[1.0,22.0,4.0,0.0...|[1.60422791737402...|[0.83260846569570...|       0.0|
|      0|     1|27.0|         4.0|       1|          3.0|     16.0|       5.0|   5.0|[1.0,27.0,4.0,1.0...|[1.48188645297092...|[0.81485734920851...|       0.0|
|      0|     1|27.0|         4.0|       1|          4.0|     14.0|       5.0|   4.0|[1.0,27.0,4.0,1.0...|[1.37900909774001...|[0.79883180985416...|       0.0|
|      0|     1|32.0|       0.125|       1|          2.0|     18.0|       5.0|   2.0|[1.0,32.0,0.125,1...|[0.28148664352576...|[0.56991065665974...|       0.0|
|      0|     1|32.0|        10.0|       1|          2.0|     20.0|       6.0|   3.0|[1.0,32.0,10.0,1....|[-0.1851761257948...|[0.45383780246566...|       1.0|
+-------+------+----+------------+--------+-------------+---------+----------+------+--------------------+--------------------+--------------------+----------+
only showing top 20 rows

 

 

 

// Extract the training summary from the LogisticRegressionModel fitted above.
val trainingSummary = lrModel.summary
trainingSummary: org.apache.spark.ml.classification.LogisticRegressionTrainingSummary = org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary@4cde233d

// Obtain the objective (loss) value per training iteration.
val objectiveHistory = trainingSummary.objectiveHistory
objectiveHistory: Array[Double] = Array(0.5613118243072733, 0.5564125149222438, 0.5365395467216898, 0.5160918427628939, 0.51304621799159, 0.5105231964507352, 0.5079869547558363, 0.5072888873031864, 0.5067113660796532, 0.506520677080951, 0.5059147658563949, 0.5053652033316485, 0.5047266888422277, 0.5045473900598205, 0.5041496504941453, 0.5034630545828777, 0.5025745763542784, 0.5019910559468922, 0.5012033102192196, 0.5009489760675826, 0.5008431925740259, 0.5008297629370251, 0.5008258245513862, 0.5008137617093257, 0.5008136785235711, 0.5008130045533166, 0.5008129888367148, 0.5008129675120628, 0.5008129469652479, 0.5008129168191972, 0.5008129132692991, 0.5008129124596163, 0.5008129124081014, 0.500812912251931, 0.5008129121356268)

objectiveHistory.foreach(loss => println(loss))
0.5613118243072733
0.5564125149222438
0.5365395467216898
0.5160918427628939
0.51304621799159
0.5105231964507352
0.5079869547558363
0.5072888873031864
0.5067113660796532
0.506520677080951
0.5059147658563949
0.5053652033316485
0.5047266888422277
0.5045473900598205
0.5041496504941453
0.5034630545828777
0.5025745763542784
0.5019910559468922
0.5012033102192196
0.5009489760675826
0.5008431925740259
0.5008297629370251
0.5008258245513862
0.5008137617093257
0.5008136785235711
0.5008130045533166
0.5008129888367148
0.5008129675120628
0.5008129469652479
0.5008129168191972
0.5008129132692991
0.5008129124596163
0.5008129124081014
0.500812912251931
0.5008129121356268
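The training summary can also report classification metrics. A sketch (assuming Spark 2.x, where the generic summary must be cast to BinaryLogisticRegressionSummary, which is already imported above):

val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary]
println(s"Training areaUnderROC: ${binarySummary.areaUnderROC}")
binarySummary.roc.show()  // DataFrame of (FPR, TPR) points along the ROC curve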

 

 

lrModel.transform(testDF).select("features", "rawPrediction", "probability", "prediction").show(30, false)
+-------------------------------------+--------------------------------------------+----------------------------------------+----------+
|features                             |rawPrediction                               |probability                             |prediction|
+-------------------------------------+--------------------------------------------+----------------------------------------+----------+
|[0.0,22.0,0.125,0.0,4.0,14.0,4.0,5.0]|[3.0182997164210517,-3.0182997164210517]    |[0.9533940335539883,0.04660596644601167]|0.0       |
|[0.0,22.0,0.417,1.0,3.0,14.0,3.0,5.0]|[2.00632544907384,-2.00632544907384]        |[0.8814596114935873,0.11854038850641263]|0.0       |
|[0.0,27.0,1.5,0.0,2.0,16.0,6.0,5.0]  |[2.311142225292793,-2.311142225292793]      |[0.9097956387984996,0.09020436120150035]|0.0       |
|[0.0,27.0,4.0,1.0,3.0,18.0,4.0,5.0]  |[1.81918359677719,-1.81918359677719]        |[0.8604681362874618,0.13953186371253828]|0.0       |
|[0.0,27.0,7.0,1.0,2.0,18.0,1.0,5.0]  |[1.351091903842644,-1.351091903842644]      |[0.7943080837836515,0.20569191621634847]|0.0       |
|[0.0,27.0,7.0,1.0,3.0,16.0,1.0,4.0]  |[1.2482145486117338,-1.2482145486117338]    |[0.7769906379765039,0.2230093620234961] |0.0       |
|[0.0,27.0,10.0,1.0,2.0,12.0,1.0,4.0] |[0.6770360847975654,-0.6770360847975654]    |[0.6630768615308953,0.33692313846910465]|0.0       |
|[0.0,32.0,10.0,1.0,4.0,17.0,5.0,4.0] |[1.343039637398138,-1.343039637398138]      |[0.7929893642953615,0.20701063570463848]|0.0       |
|[0.0,32.0,10.0,1.0,5.0,14.0,4.0,5.0] |[2.220023246987134,-2.220023246987134]      |[0.9020332500408325,0.09796674995916752]|0.0       |
|[0.0,32.0,15.0,1.0,3.0,18.0,5.0,4.0] |[0.5532756896916551,-0.5532756896916551]    |[0.6348952415965647,0.3651047584034352] |0.0       |
|[0.0,37.0,15.0,1.0,4.0,17.0,1.0,5.0] |[1.7581459850319243,-1.7581459850319243]    |[0.8529773058286395,0.14702269417136052]|0.0       |
|[0.0,52.0,15.0,1.0,5.0,9.0,5.0,5.0]  |[2.6088743974586124,-2.6088743974586124]    |[0.9314305415455806,0.06856945845441945]|0.0       |
|[0.0,52.0,15.0,1.0,5.0,12.0,1.0,3.0] |[1.8410975503955256,-1.8410975503955256]    |[0.8630784610725231,0.13692153892747697]|0.0       |
|[0.0,57.0,15.0,1.0,4.0,16.0,6.0,4.0] |[1.904911346081691,-1.904911346081691]      |[0.8704463839526814,0.1295536160473186] |0.0       |
|[1.0,22.0,4.0,0.0,1.0,18.0,5.0,5.0]  |[1.2616839124674724,-1.2616839124674724]    |[0.7793158477292919,0.22068415227070803]|0.0       |
|[1.0,22.0,4.0,0.0,2.0,18.0,5.0,5.0]  |[1.6042279173740237,-1.6042279173740237]    |[0.832608465695705,0.16739153430429493] |0.0       |
|[1.0,27.0,4.0,1.0,3.0,16.0,5.0,5.0]  |[1.4818864529709268,-1.4818864529709268]    |[0.8148573492085158,0.1851426507914842] |0.0       |
|[1.0,27.0,4.0,1.0,4.0,14.0,5.0,4.0]  |[1.379009097740017,-1.379009097740017]      |[0.7988318098541624,0.2011681901458377] |0.0       |
|[1.0,32.0,0.125,1.0,2.0,18.0,5.0,2.0]|[0.28148664352576547,-0.28148664352576547]  |[0.569910656659749,0.430089343340251]   |0.0       |
|[1.0,32.0,10.0,1.0,2.0,20.0,6.0,3.0] |[-0.1851761257948623,0.1851761257948623]    |[0.45383780246566996,0.5461621975343299]|1.0       |
|[1.0,32.0,10.0,1.0,4.0,20.0,6.0,4.0] |[0.9625930297088949,-0.9625930297088949]    |[0.7236406723848533,0.2763593276151468] |0.0       |
|[1.0,32.0,15.0,1.0,1.0,16.0,5.0,5.0] |[0.039440462424945366,-0.039440462424945366]|[0.5098588376463971,0.4901411623536029] |0.0       |
|[1.0,37.0,4.0,1.0,1.0,18.0,5.0,4.0]  |[0.7319377705508958,-0.7319377705508958]    |[0.6752303588678488,0.3247696411321513] |0.0       |
|[1.0,37.0,15.0,1.0,5.0,20.0,5.0,4.0] |[1.119955894572572,-1.119955894572572]      |[0.7539805352533917,0.24601946474660835]|0.0       |
|[1.0,42.0,15.0,1.0,4.0,17.0,6.0,5.0] |[1.4276540623429193,-1.4276540623429193]    |[0.8065355283195409,0.19346447168045908]|0.0       |
|[1.0,42.0,15.0,1.0,4.0,20.0,4.0,5.0] |[1.4935019453371354,-1.4935019453371354]    |[0.8166033137058254,0.1833966862941747] |0.0       |
|[1.0,42.0,15.0,1.0,4.0,20.0,6.0,3.0] |[0.4764020926318233,-0.4764020926318233]    |[0.6168979221749373,0.38310207782506256]|0.0       |
|[1.0,57.0,15.0,1.0,2.0,14.0,4.0,4.0] |[1.0201325344483316,-1.0201325344483316]    |[0.734998414766428,0.265001585233572]   |0.0       |
|[1.0,57.0,15.0,1.0,2.0,14.0,7.0,2.0] |[-0.04283609891898266,0.04283609891898266]  |[0.48929261249695394,0.5107073875030461]|1.0       |
|[1.0,57.0,15.0,1.0,5.0,20.0,5.0,3.0] |[1.4874352661557535,-1.4874352661557535]    |[0.8156930079647114,0.18430699203528864]|0.0       |
+-------------------------------------+--------------------------------------------+----------------------------------------+----------+
only showing top 30 rows
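The post imports BinaryClassificationEvaluator, ParamGridBuilder, and TrainValidationSplit but never uses them. A sketch of how the held-out test set could be scored, and how a small grid search might be wired up (the grid values are illustrative, not from the original post):

// AUC on the test set.
val evaluator = new BinaryClassificationEvaluator()
  .setLabelCol("affairs")
  .setRawPredictionCol("rawPrediction")
  .setMetricName("areaUnderROC")
println(s"Test areaUnderROC: ${evaluator.evaluate(lrModel.transform(testDF))}")

// Optional: search over regularization settings with a train/validation split.
val lr = new LogisticRegression().setLabelCol("affairs").setFeaturesCol("features")
val paramGrid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.0, 0.01, 0.1))
  .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))
  .build()
val tvs = new TrainValidationSplit()
  .setEstimator(lr)
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(paramGrid)
  .setTrainRatio(0.8)
val tvsModel = tvs.fit(trainingDF)  // picks the param map with the best validation AUC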
