# 在Ignite中使用线性回归算法

2018/11/22 00:24

from sklearn import datasets
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
import pandas as pd

# Load the Boston housing dataset.
# NOTE: datasets.load_boston() was deprecated in scikit-learn 1.0 and removed
# in 1.2, so this snippet requires scikit-learn < 1.2.
boston_dataset = datasets.load_boston()
x = boston_dataset.data
y = boston_dataset.target

# Split it into train and test subsets (fixed seed so the split is reproducible).
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=23)

# Save train and test sets as CSV without header or index, with the target in
# the last column: the Java loadData() reader parses every cell of every line
# as a number and treats the last cell as the price.
train_ds = pd.DataFrame(x_train, columns=boston_dataset.feature_names)
train_ds["TARGET"] = y_train
train_ds.to_csv("boston_train.csv", header=False, index=False)

test_ds = pd.DataFrame(x_test, columns=boston_dataset.feature_names)
test_ds["TARGET"] = y_test
test_ds.to_csv("boston_test.csv", header=False, index=False)

# Train linear regression model.
lr = LinearRegression()
lr.fit(x_train, y_train)

# Score the model on the test set (R^2). Print it explicitly: a bare
# expression statement only shows its value in an interactive session.
print("Score :", lr.score(x_test, y_test))

在Ignite中实现同样的流程，可以分为以下几个步骤：

1. 读取训练数据和测试数据；
2. 将训练数据和测试数据保存到Ignite中；
3. 使用训练数据拟合线性回归模型；
4. 将模型应用于测试数据；
5. 计算模型的R²得分。

## 读取训练数据和测试数据

/**
 * Reads a headerless CSV file and stores one {@link HouseObservation} per row in {@code cache}.
 *
 * Every column except the last is parsed as a feature value; the last column is parsed
 * as the price (the label). Rows are keyed by their 0-based position among the non-blank
 * lines of the file.
 *
 * @param fileName Path of the CSV file to read.
 * @param cache Cache that receives the parsed observations.
 * @throws FileNotFoundException If {@code fileName} cannot be opened.
 * @throws NumberFormatException If a cell is not a valid number.
 */
private static void loadData(String fileName, IgniteCache<Integer, HouseObservation> cache)
    throws FileNotFoundException {
    // try-with-resources closes the Scanner (and the underlying file) even if parsing fails.
    try (Scanner scanner = new Scanner(new File(fileName))) {
        int cnt = 0;

        while (scanner.hasNextLine()) {
            String row = scanner.nextLine();

            // Skip blank lines instead of failing on Double.parseDouble("").
            if (row.trim().isEmpty())
                continue;

            String[] cells = row.split(",");
            double[] features = new double[cells.length - 1];

            for (int i = 0; i < cells.length - 1; i++)
                features[i] = Double.parseDouble(cells[i]);

            double price = Double.parseDouble(cells[cells.length - 1]);

            cache.put(cnt++, new HouseObservation(features, price));
        }
    }
}

## 将训练数据和测试数据存入Ignite

// Create two caches (Integer key -> HouseObservation value): one for the training rows
// and one for the test rows.
// NOTE(review): populating these caches is not shown here; presumably loadData(...) is
// called for each with the matching CSV file — confirm.
// NOTE(review): createCache is expected to fail if a cache with the same name already
// exists (e.g. on a re-run against a live cluster); getOrCreateCache may be preferable — confirm.
IgniteCache<Integer, HouseObservation> trainData = ignite.createCache("BOSTON_HOUSING_TRAIN");
IgniteCache<Integer, HouseObservation> testData = ignite.createCache("BOSTON_HOUSING_TEST");

## 使用训练数据创建线性回归模型

DatasetTrainer<LinearRegressionModel, Double> trainer = new LinearRegressionLSQRTrainer();

LinearRegressionModel mdl = trainer.fit(
ignite,
trainData,
(k, v) -> v.getFeatures(),
// Feature extractor.

(k, v) -> v.getPrice()
// Label extractor.

Ignite以键-值（K-V）形式存储数据，因此上面的两个提取器都只使用值部分（HouseObservation）：标签（目标值）是价格Price，特征是其余各列。

## 将模型应用于测试数据

// Accumulate the two sums that make up R^2 over the test cache:
//   u: residual sum of squares, sum of (actual - predicted)^2
//   v: total sum of squares,    sum of (actual - meanPrice)^2
double meanPrice = getMeanPrice(testData);
double u = 0;
double v = 0;

// Scan every test entry; the cursor is closed by try-with-resources.
try (QueryCursor<Cache.Entry<Integer, HouseObservation>> entries = testData.query(new ScanQuery<>())) {
    for (Cache.Entry<Integer, HouseObservation> entry : entries) {
        HouseObservation obs = entry.getValue();

        double actual = obs.getPrice();
        double predicted = mdl.apply(new DenseLocalOnHeapVector(obs.getFeatures()));

        double residual = actual - predicted;
        double deviation = actual - meanPrice;

        u += Math.pow(residual, 2);
        v += Math.pow(deviation, 2);
    }
}

## 确定模型的R2得分

// R^2 = 1 - u / v (residual sum of squares over total sum of squares).
// v is 0 when the test set is empty or every test price equals the mean. R^2 is then
// undefined, and 1 - u / v would print NaN or -Infinity, so report that explicitly.
double score = 1 - u / v;

if (v == 0)
    System.out.println("Score : undefined (test prices have zero variance)");
else
    System.out.println("Score : " + score);

## 总结

Apache Ignite提供了一个机器学习算法库。从这个线性回归示例可以看到，在Ignite中创建模型、测试模型以及计算模型的R²得分都非常简单；训练好的模型还可以直接用于预测。

此外，Ignite的机器学习算法能够随以下两个方面进行扩展：

1. 集群的规模（成百上千台机器）；
2. 存储的数据量（GB、TB甚至PB级数据）。

0 评论
5 收藏
1