Spark2.2.0 MLlib RandomForestClassifier

原创
2021/05/28 19:55
阅读数 264

合并特征

trainData, testData = data.randomSplit([0.8, 0.2])
featuresArray = data.columns[:-1]
assembler = VectorAssembler().setInputCols(featuresArray).setOutputCol("features")

构建模型

# 创建随机森林
RF = RandomForestClassifier().setLabelCol("label").setFeaturesCol("features")
# 流水线
Pipeline = Pipeline().setStages([assembler,RF])

训练预测

# 训练逻辑回归模型
model = Pipeline.fit(trainData)
# 预测逻辑回归的值
prediction = model.transform(testData)

准确率评估

# 模型评估--准确率 
evaluator1 = MulticlassClassificationEvaluator().setMetricName("accuracy")
ACC = evaluator1.evaluate(prediction)
print("Accuracy:",ACC)

AUC评估

# 模型评估--AUC
evaluator2 = BinaryClassificationEvaluator().setMetricName("areaUnderROC").setRawPredictionCol("rawPrediction").setLabelCol("label")
AUC = evaluator2.evaluate(prediction)
print("Area Under ROC:",AUC)

 

 

 

 

 

 

 

 

 

 

展开阅读全文
加载中

作者的其它热门文章

打赏
0
0 收藏
分享
打赏
0 评论
0 收藏
0
分享
返回顶部
顶部