文档章节

卷积神经网络训练模拟量化实践

Ldpe2G
 Ldpe2G
发布于 01/13 13:26
字数 2106
阅读 107
收藏 0

前言

        深度学习在移动端的应用是越来越广泛,由于移动端的运算力与服务器相比还是有差距,所以在移动端

部署深度学习模型的难点就在于如何保证模型效果的同时,运行效率也有保证。在实验阶段对于模型结构可以

选择大模型,因为该阶段主要是为了验证方法的有效性。在验证完了之后,开始着手部署到移动端,这时候就

要精简模型的结构了,一般是对训好的大模型进行剪枝,或者参考现有的比如MobileNetV2和ShuffleNetV2

等轻量级的网络重新设计自己的网络模块。而算法层面的优化除了剪枝还有量化,量化就是把浮点数(高精度)

表示的权值和激活值用更低精度的整数来近似表示。低精度的优点有,相比于高精度算术运算,其在单位时间

内能处理更多的数据,而且权值量化之后模型的存储空间能进一步的减少等等[1]。对训练好的网络做量化,

在实践中尝试过TensorRT[5][8]的后训练量化算法,效果还不错。但是如果能在训练过程中去模拟量化的过程,

让网络学习去修正量化带来的误差,那么得到的量化参数应该是更准确的,而且在实际量化推断中模型的性能

损失应该能更小。而本文的内容就是介绍论文[3][4]和复现其过程中的一些细节。

    按照惯例,先给出本文实验的代码:TrainQuantization

训练模拟量化

方法介绍

    首先来看下量化的具体定义,对于量化激活值到有符号8bit整数,论文中给出的定义如下:

 

    公式中的三角形表示量化的缩放因子,x表示浮点数激活值,首先通过除以缩放因子然后最近邻取整,

然后把范围限制到一个区间内,比如量化到有符号8bit,那么范围就是 [-128, 127]。而对于权值还有一

个小的技巧,就是量化到[-127, 127]:

具体为什么这么做,论文中说了是为了实现上的优化,具体解释可以看论文[3]附录B ARM NEON details

这一小节。

    而训练量化说白了就是在forward阶段去模拟量化这个过程,本质就是把权值和激活值量化到8bit再反

量化回有误差的32bit,所以训练还是浮点,backward阶段是对模拟量化之后权值的求梯度,然后用这个

梯度去更新量化前的权值。然后在下个batch继续这个过程,通过这样子能够让网络学会去修正量化带来的

误差。

    上面给这个示意图就很直观的表示了模拟量化的过程,比如上面那条线表示的是量化前的范围[rmin, rmax],

然后下面那条线表示的就是量化之后的范围[-128, 127],比如现在要进行模拟量化的forward,先看上面那

条线从左到右数第4个圆点,通过除以缩放因子之后就会映射124到125之间的一个浮点数,然后通过最近邻

取整就取到了125,再通过乘以缩放因子返回上面第五个圆点,最后就用这个有误差的数替换原来的去forward。

forward阶段的模拟量化用公式表示如下:

backward阶段求梯度的公式表示如下:

    对于缩放因子的计算,权值和激活值的不一样,权值的计算方法是每次forward直接对权值求绝对值取

最大值,然后缩放因子 weight scale = max(abs(weight)) / 127。然后对于激活值,稍微有些不一样,

激活值的量化范围不是简单的计算最大值,而是通过EMA(exponential moving averages)在训练中

去统计这个量化范围,更新公式如下:

moving_max = moving_max * momenta + max(abs(activation)) * (1- momenta)

公式中的activation表示每个batch的激活值,而论文中说momenta取接近1的数就行了,在实验中

我是取0.95。然后缩放因子 activation scale = moving_max /128。

实现细节

    在实现过程中我没有按照论文的方法量化到无符号8bit,而是有符号8bit,第一是因为无符号8bit量化

需要引入额外的零点,增加复杂性,其次在实际应用过程中都是量化到有符号8bit。然后论文中提到,

对于权值的量化分通道进行求缩放因子,然后对于激活值的量化整体求一个缩放因子,这样效果最好。

在实践中发现有些任务权值不分通道量化效果也不错,这个还是看具体任务吧,不过本文给的实验代码

是没分的。

    然后对于卷积层之后带batchnorm的网络,因为一般在实际使用阶段,为了优化速度,batchnorm

的参数都会提前融合进卷积层的参数中,所以训练模拟量化的过程也要按照这个流程。首先把batchnorm

的参数与卷积层的参数融合,然后再对这个参数做量化。以下两张图片分别表示的是训练过程与实际应用

过程中对batchnorm层处理的区别:

     

    对于如何融合batchnorm参数进卷积层参数,看以下公式:

公式中的,W和b分别表示卷积层的权值与偏置,x和y分别为卷积层的输入与输出,则根据bn的计算

公式,可以推出融合了batchnorm参数之后的权值与偏置,Wmerge和bmerge。

    在实验中我其实是简化了融合batchnorm的流程,要是完全按照论文中的实现要复杂很多,

而且是基于已经训好的网络去做模拟量化实验的,不基于预训练模型训不起来,可能还有坑要踩。

而且在模拟量化训练过程中batchnorm层参数固定,融合batchnorm参数也是用已经训好的移动

均值和方差,而不是用每个batch的均值和方差。

    具体实现的时候就是按照论文中的这个模拟量化卷积层示例图去写训练网络结构的。

实验结果

    用VGG在Cifar10上做了下实验,效果还可以,因为是为了验证量化训练的有效性,所以训

Cifar10的时候没怎么调过参,数据增强也没做,训出来的模型精确度最高只有0.877,比最好的

结果0.93差不少,然后模拟量化是基于这个0.877的模型去做的,可以得到与普通训练精确度基本

一样的模型,可能是这个分类任务比较简单。然后得到训好的模型与每层的量化因子之后,就可以

模拟真实的量化推断过程,不过因为MXNet的卷积层不支持整型运算,所以模拟的过程也是用浮点

来模拟,具体实现细节可见示例代码。

结束语

    以上内容是根据最近的一些工作实践总结得到的一篇博客,对于论文的实现很多地方都是我自己

个人的理解,如果有读者发现哪里有误或者有疑问,也请指出,大家互相交流学习:)。

参考资料

[1] 8-Bit Quantization and TensorFlow Lite: Speeding up mobile inference with low precision

[2] Building a quantization paradigm from first principles

[3] Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference

[4] Quantizing deep convolutional networks for efficient inference: A whitepaper

[5] 8-bit Inference with TensorRT

[6] TensorRT(5)-INT8校准原理

[7] caffe-int8-convert-tool.py

© 著作权归作者所有

共有 人打赏支持
Ldpe2G
粉丝 19
博文 20
码字总数 33125
作品 0
广州
程序员
私信 提问
国家“千人”王中风教授:如何满足不同应用场景下深度神经网络模型算力和能效需求

基于神经网络的深度学习算法已经在计算机视觉、自然语言处理等领域大放异彩。然而,诸如 VGG、ResNet 和 Xception 等深度模型在取得优越性能的同时往往伴随着极高的存储空间需求和计算复杂度...

技术小能手
2017/12/25
0
0
谷歌开发者:看可口可乐公司是怎么玩转 TensorFlow 的?

在这篇客座文章中,可口可乐公司的 Patrick Brandt 将向我们介绍他们如何使用 AI 和 TensorFlow 实现无缝式购买凭证。 可口可乐的核心忠诚度计划于 2006 年以 MyCokeRewards.com 形式启动。“...

磐石001
2017/10/11
0
0
XNOR-Net:二值化卷积神经网络

Index Introduction Related Works Binary Neural Networks XNOR-Net Conclusion Introduction 神经网络模型的压缩是一个很有前景的方向。由于神经网络需要较大的计算量,目前来说,我们通常...

Efackw13
2017/08/07
0
0
一文带你读懂 WaveNet:谷歌助手的声音合成器

本文为 AI 研习社编译的技术博客,原标题 : WaveNet: Google Assistant’s Voice Synthesizer 作者 | Janvijay Singh 翻译 | 酱番梨、王立鱼、莫青悠、Disillusion 校对、整理 | 菠萝妹 原文...

雷锋字幕组
01/18
0
0
使用拓扑数据分析理解卷积神经网络模型的工作过程

1.简介 神经网络在各种数据方面处理上已经取得了很大的成功,包括图像、文本、时间序列等。然而,学术界或工业界都面临的一个问题是,不能以任何细节来理解其工作的过程,只能通过实验来检测...

【方向】
2018/07/01
0
0

没有更多内容

加载失败,请刷新页面

加载更多

大数据教程(11.9)hive操作基础知识

上一篇博客分享了hive的简介和初体验,本节博主将继续分享一些hive的操作的基础知识。 DDL操作 (1)创建表 #建表语法CREATE [EXTERNAL] TABLE [IF NOT EXISTS] table_name [(col_name ...

em_aaron
57分钟前
0
0
OSChina 周四乱弹 —— 我家猫真会后空翻

Osc乱弹歌单(2019)请戳(这里) 【今日歌曲】 @我没有抓狂 :#今天听这个# 我艇牛逼,百听不厌,太好听辣 分享 Led Zeppelin 的歌曲《Stairway To Heaven》 《Stairway To Heaven》- Led Z...

小小编辑
今天
1
0
node调用dll

先安装python2.7 安装node-gyp cnpm install node-gyp -g 新建一个Electron-vue项目(案例用Electron-vue) vue init simulatedgreg/electron-vue my-project 安装electron-rebuild cnpm ins......

Chason-洪
今天
3
0
scala学习(一)

学习Spark之前需要学习Scala。 参考学习的书籍:快学Scala

柠檬果过
今天
3
0
通俗易懂解释网络工程中的技术,如STP,HSRP等

导读 在面试时,比如被问到HSRP的主备切换时间时多久,STP几个状态的停留时间,自己知道有这些东西,但在工作中不会经常用到,就老是记不住,觉得可能还是自己基础不够牢固,知识掌握不够全面...

问题终结者
昨天
4
0

没有更多内容

加载失败,请刷新页面

加载更多

返回顶部
顶部