文档章节

PyTorch快速入门教程六(使用LSTM做图片分类)

earnpls
 earnpls
发布于 2017/07/02 09:18
字数 730
阅读 400
收藏 0

对于LSTM,我们要处理的数据是一个序列数据,对于图片而言,我们如何将其转换成序列数据呢?图片的大小是28x28,所以我们可以将其看成长度为28的序列,序列中的每个数据的维度是28,这样我们就可以将其变成一个序列数据了。

model

class Rnn(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_layer, n_class):
        super(Rnn, self).__init__()
        self.n_layer = n_layer
        self.hidden_dim = hidden_dim
        self.lstm = nn.LSTM(in_dim, hidden_dim, n_layer,
                            batch_first=True)
        self.classifier = nn.Linear(hidden_dim, n_class)

    def forward(self, x):
        # h0 = Variable(torch.zeros(self.n_layer, x.size(1),
                                #   self.hidden_dim)).cuda()
        # c0 = Variable(torch.zeros(self.n_layer, x.size(1),
                                #   self.hidden_dim)).cuda()
        out, _ = self.lstm(x)
        out = out[:, -1, :]
        out = self.classifier(out)
        return out

model = Rnn(28, 128, 2, 10)  # 图片大小是28x28
use_gpu = torch.cuda.is_available()  # 判断是否有GPU加速
if use_gpu:
    model = model.cuda()
# 定义loss和optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

这里我们定义了一个LSTM模型,我们需要传入的参数是输入数据的维数28,LSTM输出的维数128,LSTM网络层数2层以及输出的类数10。

在网络定义里面首先需要定义LSTM,而长度为28的序列传入LSTM之后输出的也是长度为28,而输入的维数是28,输出的维数由我们定义为128,最后我们只取输出的最后一个部分传入分类器求出分类概率。

out = out[:, -1, :]通过这种方式,out中的三个维度分别表示batch_size,序列长度和数据维度,所以中间的序列长度取-1,表示取序列中的最后一个数据,这个数据维度为128,再通过分类器,输出10个结果表示每种结果的概率。

另外上面注释掉的部分就是初始的h_0和c_0,这里可以自己定义,如果不定义,默认传入0,也可以根据自己的要求传入自己定义的h_0和c_0。

开始训练

把训练过程的batch_size设置为100,learning_rate设置为0.01,训练20次,最后得到的结果如下

使用LSTM做图片分类

可以发现对于简单的图像分类RNN也能得到一个较好的结果,虽然CNN更多的用在图像领域而RNN更多的用在自然语言处理中。RNNCNN彼此优缺点可以自行百度。

在这里,我整理发布了Pytorch中文文档,方便大家查询使用,同时也准备了中文论坛,欢迎大家学习交流!

Pytorch中文文档

Pytorch中文论坛

Pytorch中文文档已经发布,完美翻译,更加方便大家浏览:

Pytorch中文网:https://ptorch.com/

Pytorch中文文档:https://ptorch.com/docs/1/

本文转载自:https://ptorch.com/news/10.html

共有 人打赏支持
earnpls
粉丝 5
博文 26
码字总数 74
作品 0
昌平
程序员
PyTorch:60分钟入门学习

最近在学习PyTorch这个深度学习框架,在这里做一下整理分享给大家,有什么写的不对或者不好的地方,还请大侠们见谅啦~~~ 写在前面 本文就是主要是对PyTorch的安装,以及入门学习做了记录,...

与阳光共进早餐
01/15
0
0
教你几招搞定 LSTMS 的独门绝技(附代码)

雷锋网(公众号:雷锋网)按:本文为雷锋字幕组编译的技术博客,原标题 Taming LSTMs: Variable-sized mini-batches and why PyTorch is good for your health,作者为 William Falcon 。 翻译...

雷锋字幕组
07/12
0
0
Keras vs PyTorch:谁是「第一」深度学习框架?

  选自Deepsense.ai   作者:Rafa Jakubanis、Piotr Migdal   机器之心编译   参与:路、李泽南、李亚洲      「第一个深度学习框架该怎么选」对于初学者而言一直是个头疼的问题...

机器之心
06/30
0
0
如何使用 TensorFlow mobile 将 PyTorch 和 Keras 部署到移动设备

雷锋网(公众号:雷锋网)按:本文为雷锋字幕组编译的技术博客,原标题 Deploying PyTorch and Keras Models to Android with TensorFlow Mobile ,作者为 John Olafenwa 。 翻译 | 于志鹏 整理...

雷锋字幕组
07/12
0
0
业界 | 无缝整合PyTorch 0.4与Caffe2,PyTorch 1.0即将问世

  选自Facebook Research   作者:Bill Jia   机器之心编译   参与:思源、晓坤      在 F8 的第二天中,Facebook 正式宣布 PyTorch1.0 即将与大家见面,这是继一周前发布 0.4....

机器之心
05/03
0
0

没有更多内容

加载失败,请刷新页面

加载更多

下一页

[雪峰磁针石博客]软件测试专家工具包1web测试

web测试 本章主要涉及功能测试、自动化测试(参考: 软件自动化测试初学者忠告) 、接口测试(参考:10分钟学会API测试)、跨浏览器测试、可访问性测试和可用性测试的测试工具列表。 安全测试工具...

python测试开发人工智能安全
今天
2
0
JS:异步 - 面试惨案

为什么会写这篇文章,很明显不符合我的性格的东西,原因是前段时间参与了一个面试,对于很多程序员来说,面试时候多么的鸦雀无声,事后心里就有多么的千军万马。去掉最开始毕业干了一年的Jav...

xmqywx
今天
2
0
Win10 64位系统,PHP 扩展 curl插件

执行:1. 拷贝php安装目录下,libeay32.dll、ssleay32.dll 、 libssh2.dll 到 C:\windows\system32 目录。2. 拷贝php/ext目录下, php_curl.dll 到 C:\windows\system32 目录; 3. p...

放飞E梦想O
今天
0
0
谈谈神秘的ES6——(五)解构赋值【对象篇】

上一节课我们了解了有关数组的解构赋值相关内容,这节课,我们接着,来讲讲对象的解构赋值。 解构不仅可以用于数组,还可以用于对象。 let { foo, bar } = { foo: "aaa", bar: "bbb" };fo...

JandenMa
今天
1
0
OSChina 周一乱弹 —— 有人要给本汪介绍妹子啦

Osc乱弹歌单(2018)请戳(这里) 【今日歌曲】 @莱布妮子 :分享水木年华的单曲《中学时代》@小小编辑 手机党少年们想听歌,请使劲儿戳(这里) @须臾时光:夏天还在做最后的挣扎,但是晚上...

小小编辑
今天
48
8

没有更多内容

加载失败,请刷新页面

加载更多

下一页

返回顶部
顶部