文档章节

spark 线性回归算法(scala)

o
 osc_odyg6b92
发布于 2018/07/13 14:17
字数 942
阅读 10
收藏 0

「深度学习福利」大神带你进阶工程师,立即查看>>>

构建Maven项目,托管jar包

数据格式

//0.fp_nid,1.nsr_id,2.gf_id,2.hydm,3.djzclx_dm,4.kydjrq,5.xgrq,6.je,7.se,8.jshj,9.kpyf,10.kprq,11.zfbz,12.date_key,13.hwmc,14.ggxh,15.dw,16.sl,17.dj,18.je je1,19.se1,20.spbm,21.label

(fpid_10000201 115717 (2239 173 2011-07-12 00:00:00.0 2016-08-31 15:40:37.0 4123.08 700.92 4824.0 201704 2017-04-25 N) 201706 可视回油单向阀 HYS-1Φ1.5A 只 3.0 35.8974358974359 107.69 18.31 1090120040000000000) 0)
(fpid_10000324 253389 (7310 173 2016-01-04 00:00:00.0 2017-07-24 10:01:02.0 36609.76 6223.64 42833.4 201709 2017-09-08 N) 201711 电视机 三星743寸 台 1.0 2991.4529914529912 2991.45 508.55 1090522010000000000) 0)
(fpid_10000416 126378 (5175 173 1999-01-14 00:00:00.0 2016-05-27 14:50:49.0 25337.81 4307.39 29645.2 201612 2016-12-21 N) 201706 防水涂料 null 公斤 105.0 5.225885225885226 548.72 93.28 1070101060000000000) 0)

package Test.tett1

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.regression.LinearRegressionModel
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.regression.LinearRegression

object MLDemo3 {
  
  def main(args: Array[String]): Unit = {
                val sess = SparkSession.builder().appName("ml").master("local[4]").getOrCreate();
                val sc = sess.sparkContext;
                val dataDir = "hdfs://weekend110:9000/user/hive/warehouse/nsr2_xfp"
                //定义样例类(要分析数据的类属性)
        case class FP(fp_nid:String,nsr_id:String,gf_id:String,hydm:String,djzclx_dm:String,kydjrq:String,xgrq:String,
                    je:String,se:String,jshj:String,kpyf:String,kprq:String,zfbz:String,
                    label:String)

                //变换()
               //0.fp_nid,1.nsr_id,2.gf_id,2.hydm,3.djzclx_dm,4.kydjrq,5.xgrq,6.je,7.se,8.jshj,9.kpyf,10.kprq,11.zfbz,12.date_key,13.hwmc,14.ggxh,15.dw,16.sl,17.dj,18.je je1,19.se1,20.spbm,21.label
              val fpDataRDD = sc.textFile(dataDir).map(_.split("\001")).map(f => FP(f(0).toString, 
              f(1).toString,f(2).toString,f(3).toString,f(4).toString,f(5).toString,f(6).toString, 
                f(7).toString, f(8).toString,f(9).toString,f(10).toString,f(11).toString,f(12).toString,
                f(13).toString))

                import sess.implicits._
            
                def strToDouble(str: String): Double = {
          val regex = """([0-9]+)""".r
          val res = str match{
            case regex(num) => num
            case _ => "1"
          }
          val resDouble = res.toDouble
          resDouble
        }
                
                //转换RDD成DataFrame
                //1.fp_nid 2.nsr_id 3.gf_id 4.zfbz 5.hydm 6.djzclx_dm 7.je 8.se 9.jshj 10.kpyf 11.date_key 12.sl 13.dj 14.je1 15.se1 16.spbm
                val trainingDF = fpDataRDD.map(f => (f.label.replaceAll("[)]","").toDouble,
                    Vectors.dense( 
                    if(f.zfbz.equals("N)")) 1 else 0,
                    f.hydm.replaceAll("[(]","").toDouble,
                    f.djzclx_dm.toDouble,
                    f.kpyf.toDouble,
                    strToDouble(f.je),
            strToDouble(f.se),
            strToDouble(f.jshj)
            ))).toDF("label", "features")    
                        
                //显式数据
                trainingDF.show()
                println("======================")

                //创建线性回归对象
                val lr = new LinearRegression()
                //设置最大迭代次数
                lr.setMaxIter(50)
                //通过线性回归拟合训练数据,生成模型
                val model = lr.fit(trainingDF)

                //创建内存测试数据数据框
          val testDF = sess.createDataFrame(Seq(
                    (0,Vectors.dense(3812,171,9401.71,1598.29,11000.0,201612,1)),
                    (0,Vectors.dense(4190,173,72200.0,12274.0,84474.0,201710,1)),
                    (0,Vectors.dense(7519,173,99999.99,3000.0,102999.99,201709,1)),
                    
                    (1,Vectors.dense(1951,173,19743.59,3356.41,23100.0,201612,1)),
                    (1,Vectors.dense(5219,173,41880.35,7119.65,49000.0,201705,1)),
                    (1,Vectors.dense(5189,173,1320.93,224.56,1545.49,201611,1)),    
                    (1,Vectors.dense(1779,173,21911.4,3724.94,25636.34,201611,0))
                )).toDF("label", "features")
                
                testDF.show()

                //创建临时视图
                testDF.createOrReplaceTempView("test")
                println("======================")
                
                //利用model对测试数据进行变化,得到新数据框,查询features", "label", "prediction方面值。        
                val tested = model.transform(trainingDF).select("features", "label", "prediction");
                tested.show();
            
                //将分析的数据导入数据库                
                import java.sql.DriverManager
              tested.rdd.foreachPartition(
                it =>{
                      var url = "jdbc:mysql://localhost:3306/data?useUnicode=true&characterEncoding=utf8"
                      val conn= DriverManager.getConnection(url,"root","123456")
                      val pstat = conn.prepareStatement ("INSERT INTO `test` (`label`, `pre`,`zfbz`,`hydm`, `djzclx_dm`, "
                                                        +"`kpyf`,`je`,`se`,`jshj`) "
                                                        +"VALUES (?,?,?,?,?,?,?,?,?)")
                      for (obj <-it){
                          pstat.setString(1,obj.get(1).toString())
                          pstat.setString(2,obj.get(2).toString())
                          pstat.setString(3,obj.get(0).toString().split(",")(0).replaceAll("[\\[]", ""))
                          pstat.setString(4,obj.get(0).toString().split(",")(1))
                          pstat.setString(5,obj.get(0).toString().split(",")(2))
                          pstat.setString(6,obj.get(0).toString().split(",")(3))
                          pstat.setString(7,obj.get(0).toString().split(",")(4))
                          pstat.setString(8,obj.get(0).toString().split(",")(5))
                          pstat.setString(9,obj.get(0).toString().split(",")(6) .replaceAll("[\\]]", ""))
                          pstat.addBatch
                      }
                      try{
                          pstat.executeBatch
                      }finally{
                          pstat.close
                          conn.close
                      }
                 }
            )    
            }
}

maven的pom.xml配置文件

<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
  <modelVersion>4.0.0</modelVersion>
  <groupId>Test</groupId>
  <artifactId>tett1</artifactId>
  <version>0.0.1-SNAPSHOT</version>
  <inceptionYear>2008</inceptionYear>
  <properties>
    <scala.version>2.7.0</scala.version>
  </properties>

 <repositories>
    <repository>
      <id>scala-tools.org</id>
      <name>Scala-Tools Maven2 Repository</name>
      <url>http://scala-tools.org/repo-releases</url>
    </repository>
  </repositories>

  <pluginRepositories>
    <pluginRepository>
      <id>scala-tools.org</id>
      <name>Scala-Tools Maven2 Repository</name>
      <url>http://scala-tools.org/repo-releases</url>
    </pluginRepository>
  </pluginRepositories>

  <dependencies>
   <!--  <dependency>
      <groupId>org.scala-lang</groupId>
      <artifactId>scala-library</artifactId>
      <version>${scala.version}</version>
    </dependency> -->
     <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_2.11</artifactId>
            <version>2.1.0</version>
     </dependency>
  </dependencies>

  <build>
    <sourceDirectory>src/main/scala</sourceDirectory>
    <testSourceDirectory>src/test/scala</testSourceDirectory>
    <pluginManagement>
    <plugins>
       <plugin>
         <groupId>org.apache.maven.plugins</groupId>
          <artifactId>maven-surefire-plugin</artifactId>
          <configuration>
          <skip>true</skip>
         </configuration>
       </plugin> 
    
      <plugin>
        <groupId>org.scala-tools</groupId>
        <artifactId>maven-scala-plugin</artifactId>
        <executions>
          <execution>
            <goals>
              <goal>compile</goal>
              <goal>testCompile</goal>
            </goals>
          </execution>
        </executions>
        <configuration>
          <scalaVersion>${scala.version}</scalaVersion>
          <args>
            <arg>-target:jvm-1.5</arg>
          </args>
        </configuration>
      </plugin>
      <plugin>
        <groupId>org.apache.maven.plugins</groupId>
        <artifactId>maven-eclipse-plugin</artifactId>
        <configuration>
          <downloadSources>true</downloadSources>
          <buildcommands>
            <buildcommand>ch.epfl.lamp.sdt.core.scalabuilder</buildcommand>
          </buildcommands>
          <additionalProjectnatures>
            <projectnature>ch.epfl.lamp.sdt.core.scalanature</projectnature>
          </additionalProjectnatures>
          <classpathContainers>
            <classpathContainer>org.eclipse.jdt.launching.JRE_CONTAINER</classpathContainer>
            <classpathContainer>ch.epfl.lamp.sdt.launching.SCALA_CONTAINER</classpathContainer>
          </classpathContainers>
        </configuration>
      </plugin>
    </plugins>
    </pluginManagement>
  </build>
  <reporting>
    <plugins>
      <plugin>
        <groupId>org.scala-tools</groupId>
        <artifactId>maven-scala-plugin</artifactId>
        <configuration>
          <scalaVersion>${scala.version}</scalaVersion>
        </configuration>
      </plugin>
    </plugins>
  </reporting>
</project>

 

o
粉丝 1
博文 500
码字总数 0
作品 0
私信 提问
加载中
请先登录后再评论。
REST/HTTP 工具包--Spray

Spray 是一个开源的 REST/HTTP 工具包和底层网络 IO 包,基于 Scala 和 Akka 构建。轻量级、异步、非堵塞、基于 actor 模式、模块化和可测试是 spray 的特点。 示例代码: val responses: F...

匿名
2013/02/20
7.1K
0
分布式计算框架--DPark

DPark 是 Spark 的 Python 克隆,是一个Python实现的分布式计算框架,可以非常方便地实现大规模数据处理和迭代计算。 DPark 由豆瓣实现,目前豆瓣内部的绝大多数数据分析都使用DPark 完成,正...

Davies
2013/06/06
3.6K
1
scala的orm框架--srom

scala的orm框架,相比其他orm更为简洁 // Declare a model:case class Artist( name : String, genres : Set[Genre] )case class Genre( name : String ) // Initialize SORM, automaticall......

livehl
2013/06/18
1K
0
Akka实战:分散、聚合模式

分散与聚合:简单说就是一个任务需要拆分成多个小任务,每个小任务执行完后再把结果聚合在一起返回。 代码 http://git.oschina.net/yangbajing/akka-action 实例背景 本实例来自一个真实的线...

羊八井
2015/11/26
3.7K
13
Spark数据挖掘-深入GraphX(1)

Spark数据挖掘-深入GraphX(1) 1 网络数据集 当图被用来描述系统中的组件之间的交互关系的时候,图可以被用来表示任何系统。图原理提供了通用的语言和一系列工具来表示和分析复杂的系统。简单...

clebeg
2015/11/26
970
2

没有更多内容

加载失败,请刷新页面

加载更多

SQL 语句大全

点击上方“掌上编程”,选择“置顶或者星标” 优质文章第一时间送达! 一、基础 「1、说明:创建数据库」 CREATE DATABASE database-name    「2、说明:删除数据库」 drop database ...

GeneralMa
昨天
0
0
山东创睦网络科技有限公司:使用Python爬取全球新冠肺炎疫情数据

使用Python爬取全球新冠肺炎疫情数据 导入所需库包 获取实时数据的url 正式编写程序 查看输出结果 导入所需库包 在获取数据之前,我们需要先安装好所需的包requests和pandas: 1.如果是使用p...

osc_qv1fwke0
28分钟前
14
0
如何1年获得别人3年的工作经验(深度好文)

最近有同学问我,为什么你的工作年限不长,技术却这么厉害,我笑了笑,啥也没说。 我不是不想回答,是不知道怎么回答。在他们的定位可能就是,每方面都懂一点,遇到问题能够快速解决,就是比...

zhang_rick
今天
0
0
新基建带动行业

什么是“新基建”? 什么是“新基建”? 根据央视发布的信息来看,其涵盖了5G基站建设、新能源汽车充电桩、大数据中心、人工智能、工业互联网,特高压,城际以及城轨交通,涉及了七大领域和相...

osc_anefoz50
28分钟前
0
0
怕入错行?这群技术人写了本“择业指南”

计算机专业好找工作吗?哪些方向是当前的主流和热门方向呢? 计算机专业的你是不是还在为职业发展纠结犹豫呢? 刚经历完高考选专业的你是不是还在迷茫徘徊呢? 那么福利来啦! 《软件技术职业...

阿里云云栖号
28分钟前
21
0

没有更多内容

加载失败,请刷新页面

加载更多

返回顶部
顶部