文档章节

PyTorch快速入门教程六(使用LSTM做图片分类)

earnpls
 earnpls
发布于 2017/07/02 09:18
字数 730
阅读 421
收藏 0

对于LSTM,我们要处理的数据是一个序列数据,对于图片而言,我们如何将其转换成序列数据呢?图片的大小是28x28,所以我们可以将其看成长度为28的序列,序列中的每个数据的维度是28,这样我们就可以将其变成一个序列数据了。

model

class Rnn(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_layer, n_class):
        super(Rnn, self).__init__()
        self.n_layer = n_layer
        self.hidden_dim = hidden_dim
        self.lstm = nn.LSTM(in_dim, hidden_dim, n_layer,
                            batch_first=True)
        self.classifier = nn.Linear(hidden_dim, n_class)

    def forward(self, x):
        # h0 = Variable(torch.zeros(self.n_layer, x.size(1),
                                #   self.hidden_dim)).cuda()
        # c0 = Variable(torch.zeros(self.n_layer, x.size(1),
                                #   self.hidden_dim)).cuda()
        out, _ = self.lstm(x)
        out = out[:, -1, :]
        out = self.classifier(out)
        return out

model = Rnn(28, 128, 2, 10)  # 图片大小是28x28
use_gpu = torch.cuda.is_available()  # 判断是否有GPU加速
if use_gpu:
    model = model.cuda()
# 定义loss和optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

这里我们定义了一个LSTM模型,我们需要传入的参数是输入数据的维数28,LSTM输出的维数128,LSTM网络层数2层以及输出的类数10。

在网络定义里面首先需要定义LSTM,而长度为28的序列传入LSTM之后输出的也是长度为28,而输入的维数是28,输出的维数由我们定义为128,最后我们只取输出的最后一个部分传入分类器求出分类概率。

out = out[:, -1, :]通过这种方式,out中的三个维度分别表示batch_size,序列长度和数据维度,所以中间的序列长度取-1,表示取序列中的最后一个数据,这个数据维度为128,再通过分类器,输出10个结果表示每种结果的概率。

另外上面注释掉的部分就是初始的h_0和c_0,这里可以自己定义,如果不定义,默认传入0,也可以根据自己的要求传入自己定义的h_0和c_0。

开始训练

把训练过程的batch_size设置为100,learning_rate设置为0.01,训练20次,最后得到的结果如下

使用LSTM做图片分类

可以发现对于简单的图像分类RNN也能得到一个较好的结果,虽然CNN更多的用在图像领域而RNN更多的用在自然语言处理中。RNNCNN彼此优缺点可以自行百度。

在这里,我整理发布了Pytorch中文文档,方便大家查询使用,同时也准备了中文论坛,欢迎大家学习交流!

Pytorch中文文档

Pytorch中文论坛

Pytorch中文文档已经发布,完美翻译,更加方便大家浏览:

Pytorch中文网:https://ptorch.com/

Pytorch中文文档:https://ptorch.com/docs/1/

本文转载自:https://ptorch.com/news/10.html

共有 人打赏支持
earnpls
粉丝 5
博文 26
码字总数 74
作品 0
昌平
程序员
PyTorch:60分钟入门学习

最近在学习PyTorch这个深度学习框架,在这里做一下整理分享给大家,有什么写的不对或者不好的地方,还请大侠们见谅啦~~~ 写在前面 本文就是主要是对PyTorch的安装,以及入门学习做了记录,...

与阳光共进早餐
01/15
0
0
教你几招搞定 LSTMS 的独门绝技(附代码)

雷锋网(公众号:雷锋网)按:本文为雷锋字幕组编译的技术博客,原标题 Taming LSTMs: Variable-sized mini-batches and why PyTorch is good for your health,作者为 William Falcon 。 翻译...

雷锋字幕组
07/12
0
0
Keras vs PyTorch:谁是「第一」深度学习框架?

  选自Deepsense.ai   作者:Rafa Jakubanis、Piotr Migdal   机器之心编译   参与:路、李泽南、李亚洲      「第一个深度学习框架该怎么选」对于初学者而言一直是个头疼的问题...

机器之心
06/30
0
0
从实例掌握 pytorch 进行图像分类

背景 从入门 到沉迷 再到跳出安逸选择,根本原因是在参加天池雪浪AI制造数据竞赛的时候,几乎同样的网络模型和参数,以及相似的数据预处理方式,结果得到的成绩差距之大让我无法接受,故转为...

Spytensor
08/23
0
0
如何使用 TensorFlow mobile 将 PyTorch 和 Keras 部署到移动设备

雷锋网(公众号:雷锋网)按:本文为雷锋字幕组编译的技术博客,原标题 Deploying PyTorch and Keras Models to Android with TensorFlow Mobile ,作者为 John Olafenwa 。 翻译 | 于志鹏 整理...

雷锋字幕组
07/12
0
0

没有更多内容

加载失败,请刷新页面

加载更多

负载均衡的解决方案有哪些

负载均衡器服务可满足大型组织的需求,支持所有数据中心和跨数据中心高可靠性场景。 本地负载均衡,通过附带或者未附带持久性覆盖选项,Incapsula支持各种负载均衡算法,以优化服务器之间的流...

上树的熊
43分钟前
4
0
Java实现在线打开word文档加盖印章/盖章/签名功能

前言: 我们知道,大型一点的OA办公系统都会有很多在线处理office办公文档的需求。其中有一点也基本绕不开,那就是为文档盖章或添加手写签名来保护文档,让被盖章的文档不再被编辑。 在Java中...

山里的红杏
50分钟前
5
0
js控制输入正负数,小数点后保留两位

//限制数字function clearNoNum(obj){ //修复第一个字符是小数点 的情况. if(obj.value !=''&& obj.value.substr(0,1) == '.'){ obj.value=""; } obj.value ...

一直在成长的程序猿
53分钟前
2
0
动态代理

具体场景 为了使代理类与被代理类对第三方有相同的函数,代理类与被代理类一般实现一个公共的interface,定义如下 public interface Subject { void rent(); void hello(String s)...

wuyiyi
57分钟前
2
0
时间字段

我们看看这几个数据库中(mysql、oracle和sqlserver)如何表示时间 mysql数据库:它们分别是 date、datetime、time、timestamp和year。date :“yyyy-mm-dd”格式表示的日期值 time :“hh:...

DemonsI
58分钟前
1
0

没有更多内容

加载失败,请刷新页面

加载更多

返回顶部
顶部