文档章节

在浏览器中进行深度学习:TensorFlow.js (六)构建一个卷积网络 Convolutional Network

naughty
 naughty
发布于 2018/05/15 07:02
字数 1957
阅读 1731
收藏 54

上一篇中,我们介绍了了用TensorflowJS构建一个神经网络,然后用该模型来进行手写MINST数据的识别。和之前的基本模型比起来,模型的准确率上升的似乎不是很大。(在我的例子中,验证部分比较简单,只是一个大致的统计)甚至有些情况下,如果参数选择不当,训练效果还会更差。

卷积网络,也叫做卷积神经网络(con-volutional neural network, CNN),是一种专门用来处理具有类似网格结构的数据的神经网络。例如时间序列数据(可以认为是在时间轴上有规律地采样形成的一维网格)和图像数据(可以看作是二维的像素网格)。对于MINST手写数据来说,应用卷积网络会不会是更好的选择呢?

先上图:

代码见Codepen

该图是我应用CNN对MINST数据进行训练的结果,准确率在97%,可以说和之前的模型来比较,提高显著。要知道,要知道在获得比较高的准确率后,要提高一点都是比较困难的。那我们就简单的看看卷积网络是什么,他为什么对于手写数据的识别做的比其他模型的更好?

CNN的原理实际上是模拟了人类的视觉神经如何识别图像。每个视觉神经只负责处理不同大小的一小块画面,在不同的神经层次处理不同的信息。

卷积和核

大家可能有用过Photoshop的经验,Photoshop提供很多不同类型的滤镜来处理图像,其实那个本质上就是应用不同的核函数对图像进行卷积的结果。

卷积操作如下图所示:

“cnn deep learning gif”的图片搜索结果

左边的矩形是输入数据,也就是我们要处理的图像的张量表示。中间的矩形是核,而右边的矩形就是卷积的结果。核函数从左至右,从上到下,每次移动一个像素扫描图像,计算出卷积和的结果矩阵。

卷积的计算过程如下图:

计算就是乘法和加法,但是上图的例子计算有个错误,看你找不找得到。下面这个图计算更简单一点:

如果你能够理解上图的数学含义你就能理解,核函数其实是一个权重,对于每一个小块的图像,不同的核对不同区域的权重不一样。

如上图的两个核,左面的对于图像中间的权重为0,上面的是负向加权,而下面的正向加权。可以想像对应于普通图像,数据分布均匀,这个加权计算的结果趋近于零,对应于水平边缘,上面没有数据而下面有数据,这个加权的值就比较大,这样我们就能够检测出水平边缘。同理右边核函数对应垂直边缘。

上图就是应用垂直,水平,垂直加水平的核,对安卓小机器人图像卷积的结果。我们可以看出对应的核函数是如何识别出边缘的。

然而在学习的时候要使用什么样的核呢?我们看一下网络结构:

每一个像素都是一个特征,每一个特征是一个输入节点。每一个卷积的结果都输入到下一层的隐藏节点。核的权重就连接了输入层和隐藏层。经过0填充的输入层可以输出不同形状的卷积结果。同时可以调整扫描的步幅(stride)。

相关图片

上图中,输入为7*7,没有填充,步幅为1,输出为5*5

“cnn deep learning gif”的图片搜索结果

上图中,输入为5*5,填充1格,步幅为1,输出为5*5

“cnn deep learning gif”的图片搜索结果

上图中,输入为5*5,填充1格,步幅为2,输出为3*3

“cnn deep learning gif”的图片搜索结果

上图中,输入为2*2,填充2格,步幅为1,输出为4*4

我们可以看出来,增加填充会导致隐藏层节点数量增加,而增大步幅可以使得隐藏层的节点变少。

通过神经网络的学习,就能够确定核的权重。实际的应用,可能会有多个核,因为有许多的特征要学习。

就像我们之前看到如果要学习图像的轮廓,其实是两个不同的核的组合。

池化

池化层通常是紧跟着卷积的一层,通常是做区域的均值或者最大值操作。如下图:

如下图,池化的策略通常是取最大值或者取均值。

“cnn pooling”的图片搜索结果

池化的作用类似取样,使得下一层神经网络要处理的数据极大的缩小。减少整个网络的参数,防止出现过拟合。

 

整体结构

通常,CNN网络的由如上图所示的层次构成:

  • 输入层 Input Layer
  • 卷积层 Convolution Layer
  • 池化层 Pooling Layer
  • 全连接层 Fully Connected (Dense) Layer
  • 分类层 Softmax Classification Layer
  • 输出层 Output Layer

在了解的基本的卷积网络的概念后,我们来看看如何在TensorflowJS中实现一个CNN。

下面是模型的代码:

function cnn() {
  const model = tf.sequential();
  model.add(tf.layers.conv2d({
    inputShape: [28, 28, 1],
    kernelSize: 5,
    filters: 8,
    strides: 1,
    activation: 'relu',
    kernelInitializer: 'varianceScaling'
  }));
  model.add(tf.layers.maxPooling2d({poolSize: [2, 2], strides: [2, 2]}));
  model.add(tf.layers.conv2d({
    kernelSize: 5,
    filters: 16,
    strides: 1,
    activation: 'relu',
    kernelInitializer: 'varianceScaling'
  }));
  model.add(tf.layers.maxPooling2d({poolSize: [2, 2], strides: [2, 2]}));
  model.add(tf.layers.flatten());
  model.add(tf.layers.dense(
    {units: 10, kernelInitializer: 'varianceScaling', activation: 'softmax'}));
  return model;
}
  • tf.sequential() 创建一个连续的神经网络,自动创建输入层

  • tf.layers.conv2d 是第一层的卷积层,输入28*28*1是图像的长,高,颜色通道。核的大小是5*5,步幅是1。我们先忽略其它参数。

  • tf.layers.maxPooling2d是下一个池化层,就是以2*2的小窗口对卷积结果做池化。

  • 接着又是一个卷积和一个池化层。

  • tf.layers.flatten() 是把之前的结果打平。

  • 最后是一个softmax分类层 tf.layers.dense

类似这样一个结构

“cnn deep learning”的图片搜索结果

训练的代码如下:

const model = cnn();
const LEARNING_RATE = 0.15;
const optimizer = tf.train.sgd(LEARNING_RATE);
model.compile({
  optimizer: optimizer,
  loss: 'categoricalCrossentropy',
  metrics: ['accuracy'],
});

async function train() {
  const BATCH_SIZE = 16;
  const TRAIN_BATCHES = 1000;

  const TEST_BATCH_SIZE = 100;
  const TEST_ITERATION_FREQUENCY = 5;

  for (let i = 0; i < TRAIN_BATCHES; i++) {
    const batch = data.nextTrainBatch(BATCH_SIZE);

    let testBatch;
    let validationData;
    // Every few batches test the accuracy of the mode.
    if (i % TEST_ITERATION_FREQUENCY === 0 && i > 0 ) {
      testBatch = data.nextTestBatch(TEST_BATCH_SIZE);
      validationData = [
        testBatch.xs.reshape([TEST_BATCH_SIZE, 28, 28, 1]), testBatch.labels
      ];
    }

    // The entire dataset doesn't fit into memory so we call fit repeatedly
    // with batches.
    const history = await model.fit(
        batch.xs.reshape([BATCH_SIZE, 28, 28, 1]), batch.labels,
        {batchSize: BATCH_SIZE, validationData, epochs: 1});

    const loss = history.history.loss[0];
    const accuracy = history.history.acc[0];

    batch.xs.dispose();
    batch.labels.dispose();
    if (testBatch != null) {
      testBatch.xs.dispose();
      testBatch.labels.dispose();
    }
    await tf.nextFrame();
  }
}

如果和之前的神经网络的训练的代码比较,这里唯一的变化就是输入数据的形状。

// For CNN
batch.xs.reshape([BATCH_SIZE, 28, 28, 1])

// For NN
batch.xs.reshape([BATCH_SIZE, 784])

这是由两个网络输入层的形状来决定的。

 

大家在选取模型的时候可以考虑CNN的优缺点。

优点:

  • 共享卷积核,对高维数据处理无压力
  • 无需手动选取特征,训练好权重,即得特征
  • 分类效果好

缺点:

  • 需要调参,需要大样本量,训练最好要用GPU
  • 物理含义不明确,随着 Convolution 的堆叠,Feature Map 变得越来越抽象,人类已经很难去理解了

CNN是非常流行的深度学习的模型,广泛用于图像相关的有关领域,从阿尔法狗到自动驾驶,到处都有他的身影。如果大家希望进一步了解,可以研习下面的文章。

参考

 

© 著作权归作者所有

共有 人打赏支持
naughty
粉丝 281
博文 63
码字总数 119365
作品 0
其它
架构师
私信 提问
加载中

评论(12)

久永
久永

引用来自“sunnyluo”的评论

引用来自“久永”的评论

还是确认了两个图是出不来的:
https://cdn-images-1.medium.com/max/1600/1*ZCjPUFrB6eHPRi4eyP6aaA.gif
https://cdn-images-1.medium.com/max/1600/1*C0EwU0aknuliOsGktK6U0g.png
同时,说明也是乱码,这就是我刚才看到的挂了的图。
但是刚刷新的时候是空白,发现不了。
@naughty
@ddev

回复@sunnyluo : 说多一点啊,光喊,也不说为啥事,最烦老婆这样了。
sunnyluo
sunnyluo

引用来自“久永”的评论

还是确认了两个图是出不来的:
https://cdn-images-1.medium.com/max/1600/1*ZCjPUFrB6eHPRi4eyP6aaA.gif
https://cdn-images-1.medium.com/max/1600/1*C0EwU0aknuliOsGktK6U0g.png
同时,说明也是乱码,这就是我刚才看到的挂了的图。
但是刚刷新的时候是空白,发现不了。
@naughty
@ddev
sunnyluo
sunnyluo

引用来自“久永”的评论

图挂了几个,麻烦修复下,在线等。。。
deng
久永
久永

引用来自“久永”的评论

图挂了几个,麻烦修复下,在线等。。。

引用来自“naughty”的评论

我这里看着是好的呀

引用来自“赵传喜”的评论

我这边还是挂了几张图
和我贴的图是否一致?打开我贴的图片地址看看。我看过,这些地址和能显示的图片地址来源站点不一样。
久永
久永

引用来自“naughty”的评论

引用来自“久永”的评论

还是确认了两个图是出不来的:
https://cdn-images-1.medium.com/max/1600/1*ZCjPUFrB6eHPRi4eyP6aaA.gif
https://cdn-images-1.medium.com/max/1600/1*C0EwU0aknuliOsGktK6U0g.png
同时,说明也是乱码,这就是我刚才看到的挂了的图。
但是刚刷新的时候是空白,发现不了。
@naughty
我换了个浏览器,也没有问题。你清了缓存再试试吧。

回复@naughty : 我贴出来的那个图片地址你能访问吗?我换了两个浏览器确认出不来才二次报障的。难道是线路或者运营商问题?有没有其它同学看到,能访问显示我贴出来的文章里的图片地址吗?
开源中国-首席营养师
动图不错,把卷积和池化表达的很清楚
naughty
naughty

引用来自“久永”的评论

还是确认了两个图是出不来的:
https://cdn-images-1.medium.com/max/1600/1*ZCjPUFrB6eHPRi4eyP6aaA.gif
https://cdn-images-1.medium.com/max/1600/1*C0EwU0aknuliOsGktK6U0g.png
同时,说明也是乱码,这就是我刚才看到的挂了的图。
但是刚刷新的时候是空白,发现不了。
@naughty
我换了个浏览器,也没有问题。你清了缓存再试试吧。
久永
久永
还是确认了两个图是出不来的:
https://cdn-images-1.medium.com/max/1600/1*ZCjPUFrB6eHPRi4eyP6aaA.gif
https://cdn-images-1.medium.com/max/1600/1*C0EwU0aknuliOsGktK6U0g.png
同时,说明也是乱码,这就是我刚才看到的挂了的图。
但是刚刷新的时候是空白,发现不了。
@naughty
赵传喜
赵传喜

引用来自“久永”的评论

图挂了几个,麻烦修复下,在线等。。。

引用来自“naughty”的评论

我这里看着是好的呀
我这边还是挂了几张图
久永
久永

引用来自“naughty”的评论

引用来自“久永”的评论

图挂了几个,麻烦修复下,在线等。。。
我这里看着是好的呀

回复@naughty : 重新刷新好像好了,刚才图片处乱码。
用浏览器训练Tensorflow.js模型的18个技巧(上)

在移植现有模型(除tensorflow.js)进行物体检测、人脸检测、人脸识别后,我发现一些模型不能以最佳性能发挥。而tensorflow.js在浏览器中表现相当不错,如果你想见证浏览器内部机器学习的潜力...

【方向】
2018/10/13
0
0
五大经典卷积神经网络介绍:LeNet / AlexNet / GoogLeNet / VGGNet/ ResNet

欢迎大家关注我们的网站和系列教程:http://www.tensorflownews.com/,学习更多的机器学习、深度学习的知识! LeNet / AlexNet / GoogLeNet / VGGNet/ ResNet 前言:这个系列文章将会从经典的...

磐石001
2018/04/03
0
0
深度学习(七十)darknet 实现编写mobilenet源码

版权声明:本文为博主原创文章,欢迎转载,转载请注明原文地址、作者信息。 https://blog.csdn.net/hjimce/article/details/76175802 一、添加一个新的网络层 (1)parse.c文件中函数stringtol...

hjimce
2017/07/27
0
0
前端工程师掌握这18招,就能在浏览器里玩转深度学习

参加 2018 AI开发者大会,请点击 ↑↑↑ 作者 | Vincent Mühler 译者 | 刘旭坤 整理 | Jane 出品 | AI科技大本营 【导读】TensorFlow.js 的发布可以说是 JS 社区开发者的福音!但是在浏览器...

AI科技大本营
2018/10/20
0
0
深度学习在图像超分辨率重建中的应用

超分辨率技术(Super-Resolution)是指从观测到的低分辨率图像重建出相应的高分辨率图像,在监控设备、卫星图像和医学影像等领域都有重要的应用价值。SR可分为两类:从多张低分辨率图像重建出...

taigw
2017/03/16
0
0

没有更多内容

加载失败,请刷新页面

加载更多

阿里云vpc、快照、镜像、重置密码_重启_关机、磁盘扩容

VPC 专有网络VPC(Virtual Private Cloud)是用户基于阿里云创建的自定义私有网络, 不同的专有网络之间二层逻辑隔离,用户可以在自己创建的专有网络内创建和管理云产品实例,比如ECS、负载均...

李超小牛子
15分钟前
0
0
阿里高级技术专家:研发效能的追求永无止境

背景 大约在5年前,也就是2013年我刚加入阿里的时候,那个时候 DevOps 的风刚吹起来没多久,有家公司宣称能够一天发布几十上百次,这意味着相比传统软件公司几周一次的发布来说,他们响应商业...

阿里云官方博客
44分钟前
1
0
Android 的 ViewModel 机制源码解析

Android ViewModel 的好处是会随 Activity 销毁调用它的 clear() 方法。 我们分析一下它是怎么做到的。 1. 例子使用: a、 创建类 TestMvvmViewModel 继承 ViewModel,重写 onCleared() ,把...

亭子happy
54分钟前
2
0
WEB 开发总结

事务处理 事务的4个基本特征 1.Atomic(原子性),事务中包含的操作被看做是一个整体的业务单元,这个业务单元中的操作要么全部成功,要么全部失败,不会出现部分成功,部分失败的场景。 2....

北漂的我
今天
3
0
thinkphp5 利用七牛云 将amr格式语音文件转为mp3

$card_id 是我的本地的文件 将问价名字的后缀名去掉注意access_token的有效期public function ceshi1($card_id){ $mediaid = substr($card_id, 0, -4); $accessKey = ...

小小小壮
今天
1
0

没有更多内容

加载失败,请刷新页面

加载更多

返回顶部
顶部