文档章节

PyTorch快速入门教程三(神经网络)

earnpls
 earnpls
发布于 2017/07/02 09:11
字数 569
阅读 212
收藏 0

Neural Network

其实简单的神经网络说起来很简单,先看下图:

Neural Network

上图即可看出,其实每一层网络所做的就是 y=W×X+b,只不过W的维数由X和输出维书决定,比如X是10维向量,想要输出的维数,也就是中间层的神经元个数为20,那么W的维数就是20x10,b的维数就是20x1,这样输出的y的维数就为20。

中间层的维数可以自己设计,而最后一层输出的维数就是你的分类数目,比如我们等会儿要做的MNIST数据集是10个数字的分类,那么最后输出层的神经元就为10。

Code

有了前面两节的经验,这一节的代码就很简单了,数据的导入和之前一样

定义模型

class Neuralnetwork(nn.Module):
    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        super(Neuralnetwork, self).__init__()
        self.layer1 = nn.Linear(in_dim, n_hidden_1)
        self.layer2 = nn.Linear(n_hidden_1, n_hidden_2)
        self.layer3 = nn.Linear(n_hidden_2, out_dim)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x

model = Neuralnetwork(28*28, 300, 100, 10)
if torch.cuda.is_available():
    model = model.cuda()

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

上面定义了三层神经网络,输入是28x28,因为图片大小是28x28,中间两个隐藏层大小分别是300和100,最后是个10分类问题,所以输出层为10.

训练过程与之前完全一样,我就不再重复了,可以直接去github参看完整的代码

这是50次之后的输出结果,可以和上一节logistic回归比较一下

PyTorch神经网络

可以发现准确率大大提高,其实logistic回归可以看成简单的一层网络,从这里我们就可以看出为什么多层网络比单层网络的效果要好,这也是为什么深度学习要叫深度的原因。

本文转载自:https://ptorch.com/news/7.html

earnpls
粉丝 6
博文 26
码字总数 74
作品 0
昌平
程序员
私信 提问
PyTorch 官方中文教程包含 60 分钟快速入门教程,强化教程

PyTorch 是一个基于 Torch 的 Python 开源机器学习库,用于自然语言处理等应用程序。它主要由 d 的人工智能小组开发,不仅能够 实现强大的 GPU 加速,同时还支持动态神经网络,这一点是现在很...

磐创AI_聊天机器人
08/14
0
0
PyTorch:60分钟入门学习

最近在学习PyTorch这个深度学习框架,在这里做一下整理分享给大家,有什么写的不对或者不好的地方,还请大侠们见谅啦~~~ 写在前面 本文就是主要是对PyTorch的安装,以及入门学习做了记录,...

与阳光共进早餐
2018/01/15
0
0
PyTorch 60 分钟入门教程

PyTorch 60 分钟入门教程:PyTorch 深度学习官方入门中文教程 http://pytorchchina.com/2018/06/25/what-is-pytorch/ PyTorch 60 分钟入门教程:自动微分 http://pytorchchina.com/2018/12/...

不知道自己是谁
2018/12/25
0
0
库、教程、论文实现,这是一份超全的PyTorch资源列表(Github 2.2K星)

选自 Github,作者:bharathgs,机器之心编译。 机器之心发现了一份极棒的 PyTorch 资源列表,该列表包含了与 PyTorch 相关的众多库、教程与示例、论文实现以及其他资源。在本文中,机器之心...

机器之心
2018/10/22
0
0
PyTorch 你想知道的都在这里

本文转载地址,并进行了加工。本文适用于深度学习新手的“入门指导系列”,也有适用于老司机的论文代码实现,包括 Attention Based CNN、A3C、WGAN、BERT等等。所有代码均按照所属技术领域分...

readilen
2018/10/20
0
0

没有更多内容

加载失败,请刷新页面

加载更多

Mybatis Plus删除

/** @author beth @data 2019-10-17 00:30 */ @RunWith(SpringRunner.class) @SpringBootTest public class DeleteTest { @Autowired private UserInfoMapper userInfoMapper; /** 根据id删除......

一个yuanbeth
今天
4
0
总结

一、设计模式 简单工厂:一个简单而且比较杂的工厂,可以创建任何对象给你 复杂工厂:先创建一种基础类型的工厂接口,然后各自集成实现这个接口,但是每个工厂都是这个基础类的扩展分类,spr...

BobwithB
今天
4
0
java内存模型

前言 Java作为一种面向对象的,跨平台语言,其对象、内存等一直是比较难的知识点。而且很多概念的名称看起来又那么相似,很多人会傻傻分不清楚。比如本文我们要讨论的JVM内存结构、Java内存模...

ls_cherish
今天
4
0
友元函数强制转换

友元函数强制转换 p522

天王盖地虎626
昨天
5
0
js中实现页面跳转(返回前一页、后一页)

本文转载于:专业的前端网站➸js中实现页面跳转(返回前一页、后一页) 一:JS 重载页面,本地刷新,返回上一页 复制代码代码如下: <a href="javascript:history.go(-1)">返回上一页</a> <a h...

前端老手
昨天
5
0

没有更多内容

加载失败,请刷新页面

加载更多

返回顶部
顶部