文档章节

PyTorch快速入门教程五(rnn)

earnpls
 earnpls
发布于 2017/07/02 09:16
字数 1268
阅读 495
收藏 0

上一讲讲了cnn以及如何使用pytorch实现简单的多层卷积神经网络,下面我们将进入rnn,关于rnn将分成三个部分,

  1. 介绍rnn的基本结构以及在pytorch里面api的各个参数所表示的含义,
  2. 介绍rnn如何在MNIST数据集上做分类,
  3. 涉及一点点自然语言处理的东西。

RNN

首先介绍一下什么是rnnrnn特别擅长处理序列类型的数据,因为他是一个循环的结构

pytorch Rnn

一个序列的数据依次进入网络A,网络A循环的往后传递。

这就是RNN的基本结构类型。而最早的RNN模型,序列依次进入网络中,之前进入序列的数据会保存信息而对后面的数据产生影响,所以RNN有着记忆的特性,而同时越前面的数据进入序列的时间越早,所以对后面的数据的影响也就越弱,简而言之就是一个数据会更大程度受到其临近数据的影响。但是我们很有可能需要更长时间之前的信息,而这个能力传统的RNN特别弱,于是有了LSTM这个变体。

LSTM

pytorch LSTM

这就是LSTM的模型结构,也是一个向后传递的链式模型,而现在广泛使用的RNN其实就是LSTM,序列中每个数据传入LSTM可以得到两个输出,而这两个输出和序列中下一个数据一起又作为传入LSTM的输入,然后不断地循环向后,直到序列结束。

下面结合pytorch一步一步来看数据传入LSTM是怎么运算的

首先需要定义好LSTM网络,需要nn.LSTM(),首先介绍一下这个函数里面的参数

  1. input_size 表示的是输入的数据维数

  2. hidden_size 表示的是输出维数

  3. num_layers 表示堆叠几层的LSTM,默认是1

  4. bias True 或者 False,决定是否使用bias

  5. batch_first True 或者 False,因为nn.lstm()接受的数据输入是(序列长度,batch,输入维数),这和我们cnn输入的方式不太一致,所以使用batch_first,我们可以将输入变成(batch,序列长度,输入维数)

  6. dropout 表示除了最后一层之外都引入一个dropout

  7. bidirectional 表示双向LSTM,也就是序列从左往右算一次,从右往左又算一次,这样就可以两倍的输出

是网络的输出维数,比如M,因为输出的维度是M,权重w的维数就是(M, M)和(M, K),b的维数就是(M, 1)和(M, 1),最后经过sigmoid激活函数,得到的f的维数是(M, 1)。 pytorch bidirectional

对于第一个数据,需要定义初始的h_0和c_0,所以nn.lstm()的输入Inputs:input, (h_0, c_0),表示输入的数据以及h_0和c_0,这个可以自己定义,如果不定义,默认就是0

pytorch

第二步也是差不多的操作,只不多是另外两个权重加上不同的激活函数,一个使用的是sigmoid,一个使用的是tanh,得到的输出i_t和\tilde{C}_t都是(M, 1)。

pytorch

接着这个乘法是矩阵每个位置对应相乘,然后将两个矩阵加起来,得到的输出C_t是(M, 1)。

pytorch 最后一步得到的o_t也是(M, 1),然后C_t经过激活函数tanh,再和o_t每个位置相乘,得到的输出h_t也是(M, 1)。

最后得到的输出就是h_t和C_t,维数分别都是(M, 1),而输入x_t 维数都是(K, 1)。

lstm = nn.LSTM(10, 30, batch_first=True)

可以通过这样定义一个一层的LSTM输入是10,输出是30

lstm.weight_hh_l0.size()
lstm.weight_ih_l0.size()
lstm.bias_hh_l0.size()
lstm.bias__ih_l0.size()

可以分别得到权重的维数,注意之前我们定义的4个weights被整合到了一起,比如这个lstm,输入是10维,输出是30维,相对应的weight就是30x10,这样的权重有4个,然后pytorch将这4个组合在了一起,方便表示,也就是lstm.weight_ih_l0,所以它的维数就是120x10

我们定义一个输入

x = Variable(torch.randn((50, 100, 10)))
h0 = Variable(torch.randn(1, 50, 30))
c0 = Variable(torch.randn(1, 50 ,30))

x的三个数字分别表示batch_size为50,序列长度为100,每个数据维数为10

h0的第二个参数表示batch_size为50,输出维数为30,第一个参数取决于网络层数和是否是双向的,如果双向需要乘2,如果是多层,就需要乘以网络层数

c0的三个参数和h0是一致的

out, (h_out, c_out) = lstm(x, (h0, c0))

这样就可以得到网络的输出了,和上面讲的一致,另外如果不传入h0和c0,默认的会传入相同维数的0矩阵

这就是我们如何在pytorch上使用RNN的基本操作了,了解完最基本的参数我们才能够使用其来做应用。

本文转载自:https://ptorch.com/news/9.html

earnpls
粉丝 6
博文 26
码字总数 74
作品 0
昌平
程序员
私信 提问
PyTorch 官方中文教程包含 60 分钟快速入门教程,强化教程

PyTorch 是一个基于 Torch 的 Python 开源机器学习库,用于自然语言处理等应用程序。它主要由 d 的人工智能小组开发,不仅能够 实现强大的 GPU 加速,同时还支持动态神经网络,这一点是现在很...

磐创AI_聊天机器人
08/14
0
0
PyTorch 你想知道的都在这里

本文转载地址,并进行了加工。本文适用于深度学习新手的“入门指导系列”,也有适用于老司机的论文代码实现,包括 Attention Based CNN、A3C、WGAN、BERT等等。所有代码均按照所属技术领域分...

readilen
2018/10/20
0
0
库、教程、论文实现,这是一份超全的PyTorch资源列表(Github 2.2K星)

选自 Github,作者:bharathgs,机器之心编译。 机器之心发现了一份极棒的 PyTorch 资源列表,该列表包含了与 PyTorch 相关的众多库、教程与示例、论文实现以及其他资源。在本文中,机器之心...

机器之心
2018/10/22
0
0
PyTorch 1.2 中文文档校对活动 | ApacheCN

整体进度:https://github.com/apachecn/pytorch-doc-zh/issues/422 贡献指南:https://github.com/apachecn/pytorch-doc-zh/blob/master/CONTRIBUTING.md 项目仓库:https://github.com/ap......

ApacheCN_飞龙
09/25
10
0
PyTorch 60 分钟入门教程

PyTorch 60 分钟入门教程:PyTorch 深度学习官方入门中文教程 http://pytorchchina.com/2018/06/25/what-is-pytorch/ PyTorch 60 分钟入门教程:自动微分 http://pytorchchina.com/2018/12/...

不知道自己是谁
2018/12/25
0
0

没有更多内容

加载失败,请刷新页面

加载更多

iptables删除命令中的相关问题

最近在做一个中间件的配置工作,在配置iptables的时候,当用户想删除EIP(即释放当前连接),发现使用iptables的相关命令会提示错误。iptables: Bad rule (does a matching rule exist in t...

xiangyunyan
39分钟前
2
0
IT兄弟连 HTML5教程 HTML5表单 新增的表单属性1

HTML5 Input表单为<form>和<input>标签添加了几个新属性,属性如表1。 1 autocomplete属性 autocomplete属性规定form或input域应该拥有自动完成功能,当用户在自动完成域中开始输入时,浏览器...

老码农的一亩三分地
今天
7
0
OSChina 周五乱弹 —— 葛优理论+1

Osc乱弹歌单(2019)请戳(这里) 【今日歌曲】 @这次装个文艺青年吧 :#今日歌曲推荐# 分享米津玄師的单曲《LOSER》: mv中的舞蹈诡异却又美丽,如此随性怕是难再跳出第二次…… 《LOSER》-...

小小编辑
今天
1K
20
nginx学习笔记

中间件位于客户机/ 服务器的操作系统之上,管理计算机资源和网络通讯。 是连接两个独立应用程序或独立系统的软件。 web请求通过中间件可以直接调用操作系统,也可以经过中间件把请求分发到多...

码农实战
今天
5
0
Spring Security 实战干货:玩转自定义登录

1. 前言 前面的关于 Spring Security 相关的文章只是一个预热。为了接下来更好的实战,如果你错过了请从 Spring Security 实战系列 开始。安全访问的第一步就是认证(Authentication),认证...

码农小胖哥
今天
16
0

没有更多内容

加载失败,请刷新页面

加载更多

返回顶部
顶部