文档章节

在Ignite中使用线性回归算法

李玉珏
 李玉珏
发布于 11/22 00:24
字数 1480
阅读 90
收藏 1

在本系列前面的文章中,简单介绍了一下Ignite的机器学习网格,下面会趁热打铁,结合一些示例,深入介绍Ignite支持的一些机器学习算法。

如果要找合适的数据集,会发现可用的有很多,但是对于线性回归来说,一个非常好的备选数据集就是房价,可以非常方便地从UCI网站获取合适的数据

在本文中会训练一个线性回归模型,并且计算R2得分。

需要先准备一些数据,并且要将数据转换成Ignite支持的格式,这通常是数据科学家需要花时间做的事。

首先,需要获取原始数据并将其拆分成训练数据(80%)和测试数据(20%)。Ignite暂时还不支持专用的数据拆分,路线图中的未来版本会支持这个功能。但是就目前来说有许多可用的免费和开源工具可以执行这样的数据拆分,或者也可以用一种Ignite支持的编程语言自己编写这种代码。在本文中会使用下面自己编写的代码来实现此任务:

from sklearn import datasets
import pandas as pd

# Load Boston housing dataset.
boston_dataset = datasets.load_boston()
x = boston_dataset.data
y = boston_dataset.target

# Split it into train and test subsets.
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=23)

# Save train set.
train_ds = pd.DataFrame(x_train, columns=boston_dataset.feature_names)
train_ds["TARGET"] = y_train
train_ds.to_csv("boston-housing-train.csv", index=False, header=None)
# Save test set.
test_ds = pd.DataFrame(x_test, columns=boston_dataset.feature_names)
test_ds["TARGET"] = y_test
test_ds.to_csv("boston-housing-test.csv", index=False, header=None)

# Train linear regression model.
from sklearn.linear_model import LinearRegression
lr = LinearRegression()
lr.fit(x_train, y_train)

# Score result model.
lr.score(x_test, y_test)

这段代码从UCI网站上获取可用的数据集,执行了数据的拆分,然后计算了R2得分。返回值为0.745021053016975,或者为74.5%,之后会将此值与Ignite的进行对比。

当训练和测试数据准备好之后,就可以写应用了,本文的算法是:

  1. 读取训练数据和测试数据;
  2. 在Ignite中保存训练数据和测试数据;
  3. 使用训练数据拟合线性回归模型;
  4. 将模型应用于测试数据;
  5. 确定模型的R2得分。

由于数据集非常小,可以将其加载到标准Java数据结构中,并直接从Java程序中运行线性回归。或者,也可以将数据加载到Ignite存储中,然后对存储的数据进行线性回归。使用Ignite存储的优点是数据将分布在整个集群中,因此将执行分布式训练。对于大规模数据集,使用Ignite存储就会有很大的好处。在本例中将把数据加载到Ignite存储中。

读取训练数据和测试数据

需要读取两个CSV文件,一个是训练数据,一个是测试数据。通过下面的代码,可以从CSV文件中读取数据:

private static void loadData(String fileName, IgniteCache<Integer, HouseObservation> cache)
        throws FileNotFoundException {

   Scanner scanner = new Scanner(new File(fileName));

   int cnt = 0;
   while (scanner.hasNextLine()) {
      String row = scanner.nextLine();
      String[] cells = row.split(",");
      double[] features = new double[cells.length - 1];

      for (int i = 0; i < cells.length - 1; i++)
         features[i] = Double.valueOf(cells[i]);
      double price = Double.valueOf(cells[cells.length - 1]);

      cache.put(cnt++, new HouseObservation(features, price));
   }
}

该代码简单地一行行的读取数据,然后对于每一行,使用CSV的分隔符拆分出字段,每个字段之后将转换成double类型并且存入Ignite。

将训练数据和测试数据存入Ignite

前面的代码将数据存入Ignite,要使用这个代码,首先要创建Ignite存储,如下:

IgniteCache<Integer, HouseObservation> trainData = ignite.createCache("BOSTON_HOUSING_TRAIN");
IgniteCache<Integer, HouseObservation> testData = ignite.createCache("BOSTON_HOUSING_TEST");

使用训练数据创建线性回归模型

数据存储之后,可以像下面这样创建训练器:

DatasetTrainer<LinearRegressionModel, Double> trainer = new LinearRegressionLSQRTrainer();

然后拟合训练数据,如下:

LinearRegressionModel mdl = trainer.fit(
   ignite,
   trainData,
   (k, v) -> v.getFeatures(),  
// Feature extractor.

   (k, v) -> v.getPrice()
// Label extractor.

Ignite将数据保存为键-值(K-V)格式,因此上面的代码使用了值部分,目标值是Price,而特征位于其他列中。

将模型应用于测试数据

下一步,就可以用训练好的线性模型测试测试数据了,在Ignite的机器学习路线图中,有计划提供内置的得分计算器,但是就目前来说,可以这样做:

double meanPrice = getMeanPrice(testData);
double u = 0, v = 0;

try (QueryCursor<Cache.Entry<Integer, HouseObservation>> cursor = testData.query(new ScanQuery<>())) {
   for (Cache.Entry<Integer, HouseObservation> testEntry : cursor) {
      HouseObservation observation = testEntry.getValue();

      double realPrice = observation.getPrice();
      double predictedPrice = mdl.apply(new DenseLocalOnHeapVector(observation.getFeatures()));

      u += Math.pow(realPrice - predictedPrice, 2);
      v += Math.pow(realPrice - meanPrice, 2);
   }
}

这里计算的是残差平方和(U)和总平方和(V)。

确定模型的R2得分

可以发现,R2的值为1 - u / v:

double score = 1 - u / v;

System.out.println("Score : " + score);

输出值为0.7450194305206714,或者74.5%,这与之前的值相同。

总结

Apache Ignite提供了一个机器学习算法库。通过线性回归示例,可以看到创建模型、测试模型和确定模型的R2得分的简单性,也可以用这个模型来做预测。

目前,可用的机器学习工具有很多,但它们不能多节点扩展,只能处理少量数据。相比之下,Ignite所带来的好处是它有能力扩展下面两种能力:

  1. 集群的大小(成百上千台机器)
  2. 存储的数据量(GB、TB甚至PB级数据)

因此,Ignite可以大规模地运行机器学习。它可以以分布式处理的方式,对大数据进行真正的机器学习管理。

在机器学习系列的下一篇中,将研究另一种机器学习算法。敬请期待!

© 著作权归作者所有

共有 人打赏支持
李玉珏

李玉珏

粉丝 304
博文 66
码字总数 110992
作品 0
沈阳
技术主管
私信 提问
Apache Ignite 2.5.0 版本发布,千级节点伸缩性

Apache Ignite 2.5: 千级节点伸缩性 Apache Ignite的用户通常知道的两个关键点是-扩展性和性能。在很多分布式系统的整个生命周期中,通常会不停地改进性能,而对扩展性相关的改进次数,会比较...

李玉珏
06/01
1K
10
内存数据组织 - Apache Ignite

1.Ignite是什么? Apache Ignite是一个以内存为中心的分布式数据库、缓存和处理平台,支持事务、分析以及流式负载,可以在PB级数据上享有内存级的性能。 1.1.Ignite定位 Ignite是不是内存数据...

匿名
2015/01/10
0
8
在Ignite中使用k-最近邻(k-NN)分类算法

在本系列前面的文章中,简单介绍了一下Ignite的线性回归算法,下面会尝试另一个机器学习算法,即k-最近邻(k-NN)分类。该算法基于对象k个最近邻中最常见的类来对对象进行分类,可用于确定类成...

李玉珏
11/28
0
0
Apache Ignite 2.1.0 版本发布,全新的持久化存储

社区宣布,Apache Ignite 2.1.0版本正式发布。 这个版本包括了一个捐赠来的全新特性-Ignite持久化存储,他具有完全的内存持久化架构,使得应用同时具有基于内存的高性能以及基于磁盘的持久化...

李玉珏
2017/07/28
1K
6
全面对比,深度解析 Ignite 与 Spark

经常有人拿 Ignite 和 Spark 进行比较,然后搞不清两者的区别和联系。Ignite 和 Spark,如果笼统归类,都可以归于内存计算平台,然而两者功能上虽然有交集,并且 Ignite 也会对 Spark 进行支...

编辑部的故事
09/13
0
0

没有更多内容

加载失败,请刷新页面

加载更多

什么是以太坊DAO?(二)

Decentralized Autonomous Organization,简称DAO,以太坊中重要的概念。一般翻译为去中心化的自治组织。 在上一节中,我们为了展示什么是DAO创建了一个合约,就像一个采用邀请制的俱乐部,会...

geek12345
25分钟前
1
0
全屋WiFi彻底无死角 这才是终极解决方案

无线网络现在不仅在家庭中不可或缺,在酒店、医院、学校等场景中的需求也越来越多。尤其是这些场景中,房间多但也需要每个房间都能够完美覆盖WiFi,传统的吸顶式AP就无法很好的解决问题。 H3...

linux-tao
39分钟前
4
0
Python日期字符串比较

需要用python的脚本来快速检测一个文件内的二个时间日期字符串的大小,其实实现很简单,首先一些基础的日期格式化知识如下 复制代码 %a星期的简写。如 星期三为Web %A星期的全写。如 星期三为...

dragon_tech
39分钟前
3
0
ORA 各种oraclesql错误

ORA-00001: 违反唯一约束条件 (.) ORA-00017: 请求会话以设置跟踪事件 ORA-00018: 超出最大会话数 ORA-00019: 超出最大会话许可数 ORA-00020: 超出最大进程数 () ORA-00021: 会话附属于其它某...

青峰Jun19er
43分钟前
3
0
没错,老板让我写个 BUG!

前言 标题没有看错,真的是让我写个 bug! 刚接到这个需求时我内心没有丝毫波澜,甚至还有点激动。这可是我特长啊;终于可以光明正大的写 bug 了🙄。 先来看看具体是要干啥吧,其实主要就是...

crossoverJie
56分钟前
118
0

没有更多内容

加载失败,请刷新页面

加载更多

返回顶部
顶部