文档章节

用Mxnet对California房地产数据做线性回归分析

q
 qinhui99
发布于 2017/03/18 10:22
字数 594
阅读 179
收藏 0

改写自去年写的一篇博客《用Mxnet和Tensorflow对California房地产数据做分析》。加了一些新的领悟。

Mxnet和Tensorflow都是我正在学习的东西,因为不熟,所以想多做些练习来加深理解。于是就用California房地产数据来练练手。

这里的California房地产数据引用自sklearn自带的数据集(2万多条数据)。该数据集比正规的版本做了简化,只有9个字段。 
其中,特征属性有8个。 
feature_names = ["MedInc""HouseAge""AveRooms","AveBedrms",
                 "Population""AveOccup""Latitude","Longitude"]
该数据的目标值是该区域的平均房价。也就是说需要预测的是精确的平均房价,而不是分类标签。具体代码如下:
# coding=utf-8
from sklearn.datasets import fetch_california_housing
from sklearn.preprocessing import StandardScaler
import numpy as np
import mxnet as mx
from sklearn.utils import shuffle
import logging

housing = fetch_california_housing()
m, n = housing.data.shape
housing_data_plus_bias = np.c_[np.ones((m, 1)), housing.data]

scaler = StandardScaler()
scaled_housing_data = scaler.fit_transform(housing.data)
scaled_housing_data_plus_bias = np.c_[np.ones((m, 1)), scaled_housing_data]

n_epochs = 15
learning_rate = 0.025
#刚开始的时候,batch_size越大,程序速度就越快。经过一个阈值后,就没多大效果了。而且最终的MSE也没有太多改进。
batch_size =128
# shuffle data
#X, y = shuffle(scaled_housing_data_plus_bias, housing.target)
X, y = (scaled_housing_data_plus_bias, housing.target)

# 定义符号
x_sym = mx.symbol.Variable('data')
y_sym = mx.symbol.Variable('softmax_label')

# 定义网络。这里可以根据需要定义多层网络,例如下面定义了3层网络,最后输出到线性回归处理器里
fc1 = mx.symbol.FullyConnected(data=x_sym, num_hidden=40, name='pre')
act1 = mx.symbol.Activation(data = fc1, name='act1', act_type="relu")
fc2 = mx.sym.FullyConnected(data=act1, name='fc2', num_hidden=20)
act2 = mx.symbol.Activation(data = fc2, name='act2', act_type="relu")
fc3 = mx.sym.FullyConnected(data=act2, name='fc3', num_hidden=1)
# mxnet自带的线性回归
loss = mx.symbol.LinearRegressionOutput(data=fc3,label=y_sym, name='loss')

# 定义模型
model = mx.model.FeedForward(
    ctx=mx.cpu(), symbol=loss, num_epoch=n_epochs,
    learning_rate=learning_rate,
    optimizer='adam'
)
logging.basicConfig(level=logging.INFO)

# Build iterator
slice_index=20500
train_iter = mx.io.NDArrayIter(data=X[:18000], label=y[:18000], batch_size=batch_size, shuffle=True)
eval_iter = mx.io.NDArrayIter(data=X[18000:slice_index], label=y[18000:slice_index], batch_size=batch_size, shuffle=True)

test_iter=mx.io.NDArrayIter(data=X[slice_index:slice_index+1],  shuffle=False)

#训练时打印出mse和rmse指标
eval_metrics = ['mse']
eval_metrics.append('rmse')
#训练
model.fit(X = train_iter,
          eval_metric=eval_metrics,
          eval_data=eval_iter)
#预测测试例子
r=model.predict(test_iter)
#打印预测结果和真实值,对比看看
print (r,y[slice_index:slice_index+1])

因为使用了多层神经网络训练,训练结果比原来使用一层的神经网络看起来要好些,MSE指标从0.54下降到0.34左右。所以,神经网络深一些,模型的效果确实会好一些。

© 著作权归作者所有

q
粉丝 66
博文 73
码字总数 34091
作品 0
深圳
程序员
私信 提问
MXNet 宣布支持 Keras 2,可更加方便快捷地实现 CNN 及 RNN 分布式训练

雷锋网(公众号:雷锋网) AI 研习社按,近期,AWS 表示 MXNet 支持 Keras 2,开发者可以使用 Keras-MXNet 更加方便快捷地实现 CNN 及 RNN 分布式训练。AI 研习社将 AWS 官方博文编译如下。 Ke...

孔令双
2018/05/23
0
0
mxnet训练模型、导出模型、加载模型 进行预测(python和C++)

版权声明:原创文章如需转载,请在左侧博主描述栏目扫码联系我并取得授权,谢谢 https://blog.csdn.net/u012234115/article/details/80656030 mxnet支持将已训练的模型导出成网络和参数分离的...

踏莎行hyx
2018/06/11
0
0
业界 | MXNet开放支持Keras,高效实现CNN与RNN的分布式训练

  选自AWS Machine Learning Blog   作者:Lai Wei、Kalyanee Chendke、Aaron Markham、Sandeep Krishnamurthy   机器之心编译   参与:路、王淑婷      今日 AWS 发布博客宣布 ...

机器之心
2018/05/22
0
0
云上深度学习实践(二)-云上MXNet实践

目录 云上深度学习实践(一)-GPU云服务器TensorFlow单机多卡训练性能实践 云上深度学习实践(二)-云上MXNet实践 1 MXNet 简介 1.1 MXNet特点 MXNet是一个全功能,灵活可编程和高扩展性的深...

撷峰
2018/07/13
0
0
机器学习和深度学习的最佳框架大比拼

在过去的一年里,咱们讨论了六个开源机器学习和/或深度学习框架:Caffe,Microsoft Cognitive Toolkit(又名CNTK 2),MXNet,Scikit-learn,Spark MLlib和TensorFlow。如果把网撒得大些,可...

凝小紫
2017/02/05
29.5K
4

没有更多内容

加载失败,请刷新页面

加载更多

mysql-connector-java升级到8.0后保存时间到数据库出现了时差

在一个新项目中用到了新版的mysql jdbc 驱动 <dependency>     <groupId>mysql</groupId>     <artifactId>mysql-connector-java</artifactId>     <version>8.0.18</version> ......

ValSong
52分钟前
5
0
Spring Boot 如何部署到 Linux 中的服务

打包完成后的 Spring Boot 程序如何部署到 Linux 上的服务? 你可以参考官方的有关部署 Spring Boot 为 Linux 服务的文档。 文档链接如下: https://docs.ossez.com/spring-boot-docs/docs/r...

honeymoose
54分钟前
5
0
Spring Boot 2 实战:使用 Spring Boot Admin 监控你的应用

1. 前言 生产上对 Web 应用 的监控是十分必要的。我们可以近乎实时来对应用的健康、性能等其他指标进行监控来及时应对一些突发情况。避免一些故障的发生。对于 Spring Boot 应用来说我们可以...

码农小胖哥
今天
6
0
ZetCode 教程翻译计划正式启动 | ApacheCN

原文:ZetCode 协议:CC BY-NC-SA 4.0 欢迎任何人参与和完善:一个人可以走的很快,但是一群人却可以走的更远。 ApacheCN 学习资源 贡献指南 本项目需要校对,欢迎大家提交 Pull Request。 ...

ApacheCN_飞龙
今天
4
0
CSS定位

CSS定位 relative相对定位 absolute绝对定位 fixed和sticky及zIndex relative相对定位 position特性:css position属性用于指定一个元素在文档中的定位方式。top、right、bottom、left属性则...

studywin
今天
7
0

没有更多内容

加载失败,请刷新页面

加载更多

返回顶部
顶部