文档章节

PyTorch快速入门教程四(cnn:卷积神经网络 )

earnpls
 earnpls
发布于 2017/07/02 09:13
字数 803
阅读 139
收藏 0

以前的教程中我们已经完成了基础部分,接下来进入深度学习部分,第一个要讲的是cnn,也就是卷积神经网络: cnn,也就是卷积神经网络

数据集仍然是使用MNIST手写字体,和之前一样做同样的预处理。

model

# 定义 Convolution Network 模型
class Cnn(nn.Module):
    def __init__(self, in_dim, n_class):
        super(Cnn, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_dim, 6, 3, stride=1, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5, stride=1, padding=0),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),
        )

        self.fc = nn.Sequential(
            nn.Linear(400, 120),
            nn.Linear(120, 84),
            nn.Linear(84, n_class)
        )

    def forward(self, x):
        out = self.conv(x)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

model = Cnn(1, 10)  # 图片大小是28x28
use_gpu = torch.cuda.is_available()  # 判断是否有GPU加速
if use_gpu:
    model = model.cuda()
# 定义loss和optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

以上就是网络的模型的部分了。和之前比主要增加了这些不一样的部分

  1. 1、nn.Sequential()

这个表示将一个有序的模块写在一起,也就相当于将神经网络的层按顺序放在一起,这样可以方便结构显示

  1. 2、nn.Conv2d()

这个是卷积层,里面常用的参数有四个,in_channelsout_channelskernel_sizestridepadding

in_channels表示的是输入卷积层的图片厚度
out_channels表示的是要输出的厚度
kernel_size表示的是卷积核的大小,可以用一个数字表示长宽相等的卷积核,比如kernel_size=3,也可以用不同的数字表示长宽不同的卷积核,比如kernel_size=(3, 2) stride表示卷积核滑动的步长

padding表示的是在图片周围填充0的多少,padding=0表示不填充,padding=1四周都填充1维

  1. 3、nn.ReLU()
    这个表示使用ReLU激活函数,里面有一个参数inplace,默认设置为False,表示新创建一个对象对其修改,也可以设置为True,表示直接对这个对象进行修改
  2. 4、nn.MaxPool2d() 这个是最大池化层,当然也有平均池化层,里面的参数有kernel_sizestridepadding

kernel_size表示池化的窗口大小,和卷积层里面的 kernel_size是一样的

stride也和卷积层里面一样,需要自己设置滑动步长

padding也和卷积层里面的参数是一样的,默认是0

模型需要传入的参数是输入的图片维数以及输出的种类数

train

训练的过程是一样的,只是输入图片不再需要展开

这是训练20个epoch的结果,当然你也可以增加训练次数,修改里面的参数达到更好的效果,可以参考一下Lenet的网络结构,自己重新写一写

pytorch-train

大体上简单的卷积网络就是这么构建的,当然现在也有很多复杂的网络,比如vgg,inceptionv1-v4,resnet以及修正的inception-resnet,这些网络都是深层的卷积网络,有兴趣的同学可以去看看pytorch的官方代码实现,或者去github上搜索相应的网络。

本文转载自:https://ptorch.com/news/8.html

earnpls
粉丝 6
博文 26
码字总数 74
作品 0
昌平
程序员
私信 提问
库、教程、论文实现,这是一份超全的PyTorch资源列表(Github 2.2K星)

选自 Github,作者:bharathgs,机器之心编译。 机器之心发现了一份极棒的 PyTorch 资源列表,该列表包含了与 PyTorch 相关的众多库、教程与示例、论文实现以及其他资源。在本文中,机器之心...

机器之心
2018/10/22
0
0
PyTorch 你想知道的都在这里

本文转载地址,并进行了加工。本文适用于深度学习新手的“入门指导系列”,也有适用于老司机的论文代码实现,包括 Attention Based CNN、A3C、WGAN、BERT等等。所有代码均按照所属技术领域分...

readilen
2018/10/20
0
0
10分钟快速入门PyTorch (0)

之前有很多小伙伴私信我说文章思想能看懂,但是pytorch的部分因为没有看过pytorch教程所以一脸懵逼。对此我也表示很无奈,既然大家不愿意去官网看教程,那么我就将我学习pytorch的经验写出来...

SherlockLiao
2017/05/11
0
0
MNIST数据集深度学习实践汇总

Why MNIST MNIST数据集对深度学习初学者来说应该是最友好的数据集了: 拿来即用,你只需要专注于模型搭建就好(数据处理真的很费时间); 数据集不大,很适合普通玩家,一般的PC都能跑的动,...

Awesome_Tang
05/02
0
0
PyTorch 官方中文教程包含 60 分钟快速入门教程,强化教程

PyTorch 是一个基于 Torch 的 Python 开源机器学习库,用于自然语言处理等应用程序。它主要由 d 的人工智能小组开发,不仅能够 实现强大的 GPU 加速,同时还支持动态神经网络,这一点是现在很...

磐创AI_聊天机器人
08/14
0
0

没有更多内容

加载失败,请刷新页面

加载更多

代理模式之JDK动态代理 — “JDK Dynamic Proxy“

动态代理的原理是什么? 所谓的动态代理,他是一个代理机制,代理机制可以看作是对调用目标的一个包装,这样我们对目标代码的调用不是直接发生的,而是通过代理完成,通过代理可以有效的让调...

code-ortaerc
29分钟前
4
0
学习记录(day05-标签操作、属性绑定、语句控制、数据绑定、事件绑定、案例用户登录)

[TOC] 1.1.1标签操作v-text&v-html v-text:会把data中绑定的数据值原样输出。 v-html:会把data中值输出,且会自动解析html代码 <!--可以将指定的内容显示到标签体中--><标签 v-text=""></......

庭前云落
今天
7
0
VMware vSphere的两种RDM磁盘

在VMware vSphere vCenter中创建虚拟机时,可以添加一种叫RDM的磁盘。 RDM - Raw Device Mapping,原始设备映射,那么,RDM磁盘是不是就可以称作为“原始设备映射磁盘”呢?这也是一种可以热...

大别阿郎
今天
10
0
【AngularJS学习笔记】02 小杂烩及学习总结

本文转载于:专业的前端网站☞【AngularJS学习笔记】02 小杂烩及学习总结 表格示例 <div ng-app="myApp" ng-controller="customersCtrl"> <table> <tr ng-repeat="x in names | orderBy ......

前端老手
昨天
14
0
Linux 内核的五大创新

在科技行业,创新这个词几乎和革命一样到处泛滥,所以很难将那些夸张的东西与真正令人振奋的东西区分开来。Linux内核被称为创新,但它又被称为现代计算中最大的奇迹,一个微观世界中的庞然大...

阮鹏
昨天
18
0

没有更多内容

加载失败,请刷新页面

加载更多

返回顶部
顶部