文档章节

[DNN] 尝试理解深度神经网络的Large-batch魔咒

Airship
 Airship
发布于 2017/06/13 09:33
字数 1585
阅读 15
收藏 0
点赞 0
评论 0

[DNN] 尝试理解深度神经网络的Large-batch魔咒

[DNN] 尝试理解深度神经网络的Large-batch魔咒

张泰源张泰源

1 天前

最近贵司的“一小时训练ImageNet”论文在国内外各种刷屏(https://research.fb.com/publications/imagenet1kin1h/),看了一下,确实非常实用主义的文章,介绍很多有用的trick,包括系统实现上的很多坑都覆盖到了。其中谈到加速训练的难点之一是:需要用到更大的mini-batch size,而这通常会降低准确率,所以他们通过linear-scaling learning rate解决了这个问题。看到这里我对于这个难点产生了疑问——batch size越大,不应该训练的方差越小,随机性越小,从而能够更准确地拟合数据集么?

从一个对深度学习接触不多的人(比如我)的角度,这点确实有点反直觉。当batch-size不断增大,直到跟数据集一样大的时候,SGD (Stochastic Gradient Descent)就变成了最朴素的GD,一次梯度更新会扫描一遍所有的数据来算梯度。看教科书和在CMU上Machine Learning的时候被灌输的理念都是:SGD相对于GD,或者小的batch相对于大的batch,有助于更快收敛,但是准确度会下降。为什么到了深度神经网络这里就反过来了呢?

我的第一猜想是:神经网络的函数空间非常non-convex,所以mini-batch越小就越容易不断跳出local minima,寻找更好的最小值。但是自己马上感觉这个猜想有很多漏洞,不能自圆其说,所以我去查证了一下其他人的分析——Facebook的论文原文有提及过这个问题,以及ICLR 2017上也有一篇论文针对这个问题分析了一下。有趣的事,两篇文章的观点并不相同,Facebook的论文还轻踩了对方一下说“根据我们跑的实验这事不是你们说的那样儿的”。由于ICLR 的论文先出来,我们先看看它怎么说:

ICLR 2017: ON LARGE-BATCH TRAINING FOR DEEP LEARNING: GENERALIZATION GAP AND SHARP MINIMA https://openreview.net/pdf?id=H1oyRlYgg 

这篇文章主要研究了“为什么Large batch size会让错误率提高”的问题,提出了四个可能的猜想:

(i) LB methods over-fit the model;

(ii) LB methods are attracted to saddle points;

(iii) LB methods lack the explorative properties of SB methods and tend to zoom-in on the minimizer closest to the initial point;

(iv) SB and LB methods converge to qualitatively different minimizers with differing generalization properties.

然后通过实验,得出了支持(iii)和(iv)的证据。也就是说,主要是两点原因:

1) LB (Large-Batch) 方法探索性太差,容易在离起始点附近很近的地方停下来

2) LB和SB由于训练方式上的差异,最终会导致它们最终收敛的点具有一些数学属性的差异

#1 很好理解,跟我前面的猜想有点类似。这里着重谈谈#2 - 文章谈到,LB方法会收敛到Sharp-minimum,而SB方法会收敛到Flat-minimum。这两种minimum的差别如图所示:

在同样的Bias下,明显Flat的曲线比Sharp的曲线更加接近真实情况,所以Flat Minimum的generalization performance更好。

然后,基于这个假设,他们给出的解决方案是:先用SB方法训练几个epoch,让它先探索一下,找到一个比较Flat的区域,再用LB方法慢慢收敛到正确的地方。论文给出了performance vs. # of epoch trained with SB,但个人感觉不是很有说服力。。。

Facebook: Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour

再回到Facebook这篇文章,作者认为,LB之所以不work,不是因为上面那篇论文提到的泛化能力的问题,而主要是一个optimization issue(我的理解是优化过程/优化算法的问题)。文章没有给出理论分析, 而是直接给出了实验数据:首先,这篇论文是基于“Linear Scaling Learning Rate”来做的,简单来说,假如说原来batch size是256,learning rate是0.1;那么当把batch size设成8192的时候,learning rate就设成3.2 。batch size翻多少倍,learning rate就翻多少倍。然后,基于这个方法,论文作者发现,如果用LB方法,刚开始就用很大的learning rate的话,效果其实是很差的;但是,只要刚开始把LR设小点,后来逐步把LR提高到正常的大小,那么效果拔群,LB能够得到跟SB几乎一毛一样的training curve,以及基本相同的准确度。

基于这个观察,作者认为,LB不work的主要原因是

large minibatch sizes are challenged by optimization difficulties in early training

(至于为什么,这个跟Linear Scaling Learning Rate的assumption有关:简单来说,就是Linear Scaling Learning Rate这个trick是基于一定的assumption的,而这个assumption在网络权重急剧变化的时候——也就是刚开始训练的时候——是不成立的。所以,一开始就应用那么大的learning rate会出事。我解释的不是很清楚,具体可以去看原论文)

总结

上篇两篇论文各有千秋:ICLR那篇着重理论分析,用漂亮的实验验证了Sharp-minimum和Flat-minimum的区别,启发性非常大,但是给出的解决方案不是很令人信服;Facebook这篇直接从实战经验出发,实验和解释都比较令人信服,不过理论上相对弱些。

对于两者的Claim,其实不能说谁对谁错,因为两者的实验方法不一样;ICLR那篇没有应用Linear Scaling Learning Rate而是直接应用了ADAM来作为optimizer,得出的结果跟Facebook的肯定不能直接相比。如果ICLR那篇论文的作者可以使用Facebook的方法论重新跑实验的话,说不定得出的结论会有很大不同。甚至说,双方的结论其实不完全互斥,而是可以被统一成一个理论(比如我现在拍脑袋想的:刚开始训练的时候,Large-batch得出来的梯度不准确,所以如果设的learning rate太大,就更加容易陷入Sharp-minimum出不来,从而影响到后面的优化,之类之类的)。

「真诚赞赏,手留余香」

本文转载自:https://zhuanlan.zhihu.com/p/27349632

共有 人打赏支持
Airship
粉丝 34
博文 852
码字总数 18996
作品 0
南京
高级程序员
Deep Learning方向的paper

个人阅读的Deep Learning方向的paper整理,分了几部分吧,但有些部分是有交叉或者内容重叠,也不必纠结于这属于DNN还是CNN之类,个人只是大致分了个类。目前只整理了部分,剩余部分还会持续更...

langb2014
2016/03/06
0
0
前沿 | 利用遗传算法优化神经网络:Uber提出深度学习训练新方式

  选自Uber   作者:Kenneth O. Stanley、Jeff Clune   机器之心编译   参与:陈韵竹、刘晓坤      在深度学习领域,对于具有上百万个连接的多层深度神经网络(DNN),现在往往通...

机器之心
2017/12/22
0
0
学界 | 深度神经网络为什么不易过拟合?傅里叶分析发现固有频谱偏差

  选自arXiv   作者:Naism Rahaman等   机器之心编译   参与:Geek AI、刘晓坤      过参数化的深度神经网络是一类表达能力极强的函数,甚至能 100% 记住随机数据。这向我们提出...

机器之心
07/15
0
0
深度学习加速器Layer Normalization-LN

前面介绍了Batch Normalization(BN),公众号菜单栏可以获得文章链接,今天介绍一种和BN类似的深度学习normalize算法Layer Normalization(LN)。 LN提出:BN针对一个minibatch的输入样本,...

lqfarmer
2017/05/25
0
0
深度神经网络:攻陷语音识别最后堡垒的杀手锏?

【赛迪网讯】1月30日消息,尽管手机终端上各种语音助手的混战正如火如荼,但对于一些有着浓重口音的用户而言,语音助手的体验似乎远没有宣传的那么好:语音助手听不懂自己的话,这才是最大的...

修真0
2013/02/06
1K
4
斯坦福完全可解释深度神经网络:你需要用决策树搞点事

  选自Stanford   机器之心编译   参与:路雪、黄小天、刘晓坤      近日,斯坦福大学计算机科学博士生 Mike Wu 发表博客介绍了他对深度神经网络可解释性的探索,主要提到了树正则...

机器之心
01/10
0
0
Hinton提出泛化更优的「软决策树」:可解释DNN具体决策

近日,针对泛化能力强大的深度神经网络(DNN)无法解释其具体决策的问题,深度学习殿堂级人物 Geoffrey Hinton 等人发表 arXiv 论文提出「软决策树」(Soft Decision Tree)。相较于从训练数...

zchang81
2017/11/29
0
0
深度神经网络(DNN)模型与前向传播算法(转载)

转载地址:点击打开链接 深度神经网络(Deep Neural Networks, 以下简称DNN)是深度学习的基础,而要理解DNN,首先我们要理解DNN模型,下面我们就对DNN的模型与前向传播算法做一个总结。 1....

qianyi_wei
04/23
0
0
我从吴恩达深度学习课程中学到的21个心得:加拿大银行首席分析师“学霸“笔记分享

大数据文摘作品 编译:新知之路、小饭盆、钱天培 今年8月,吴恩达的深度学习课程正式上线,并即刻吸引了众多深度学习粉丝的“顶礼膜拜”。一如吴恩达此前在Coursera上的机器学习课程,这几门...

dzjx2eotaa24adr
2017/12/07
0
0
详解卷积神经网络(CNN)在语音识别中的应用

欢迎大家前往腾讯云社区,获取更多腾讯海量技术实践干货哦~ 作者:侯艺馨 前言 总结目前语音识别的发展现状,dnn、rnn/lstm和cnn算是语音识别中几个比较主流的方向。2012年,微软邓力和俞栋老...

腾讯云社区
2017/12/01
0
0

没有更多内容

加载失败,请刷新页面

加载更多

下一页

【面试题】盲人坐飞机

有100位乘客乘坐飞机,其中有一位是盲人,每位乘客都按自己的座位号就坐。由于盲人看不见自己的座位号,所以他可能会坐错位置,而自己的座位被占的乘客会随便找个座位就坐。问所有乘客都坐对...

garkey
今天
0
0
谈谈神秘的ES6——(二)ES6的变量

谈谈神秘的ES6——(二)ES6的变量 我们在《零基础入门JavaScript》的时候就说过,在ES5里,变量是有弊端的,我们先来回顾一下。 首先,在ES5中,我们所有的变量都是通过关键字var来定义的。...

JandenMa
今天
1
0
arts-week1

Algorithm 594. Longest Harmonious Subsequence - LeetCode 274. H-Index - LeetCode 219. Contains Duplicate II - LeetCode 217. Contains Duplicate - LeetCode 438. Find All Anagrams ......

yysue
今天
0
0
NNS拍卖合约

前言 关于NNS的介绍,这里就不多做描述,相关的信息可以查看NNS的白皮书http://doc.neons.name/zh_CN/latest/nns_background.html。 首先nns中使用的竞价货币是sgas,关于sgas介绍可以戳htt...

红烧飞鱼
今天
1
0
Java IO类库之管道流PipeInputStream与PipeOutputStream

一、java管道流介绍 在java多线程通信中管道通信是一种重要的通信方式,在java中我们通过配套使用管道输出流PipedOutputStream和管道输入流PipedInputStream完成线程间通信。多线程管道通信的...

老韭菜
今天
0
0
用Python绘制红楼梦词云图,竟然发现了这个!

Python在数据分析中越来越受欢迎,已经达到了统计学家对R的喜爱程度,Python的拥护者们当然不会落后于R,开发了一个个好玩的数据分析工具,下面我们来看看如何使用Python,来读红楼梦,绘制小...

猫咪编程
今天
1
0
Java中 发出请求获取别人的数据(阿里云 查询IP归属地)

1.效果 调用阿里云的接口 去定位IP地址 2. 代码 /** * 1. Java中远程调用方法 * http://localhost:8080/mavenssm20180519/invokingUrl.action * @Title: invokingUrl * @Description: * @ret......

Lucky_Me
今天
1
0
protobuf学习笔记

相关文档 Protocol buffers(protobuf)入门简介及性能分析 Protobuf学习 - 入门

OSC_fly
昨天
0
0
Mybaties入门介绍

Mybaties和Hibernate是我们在Java开发中应用的比较多的两个ORM框架。当然,目前Mybaties正在慢慢取代Hibernate,这是因为相比较Hibernate而言Mybaties性能更好,响应更快,更加灵活。我们在开...

王子城
昨天
2
0
编程学习笔记之python深入之装饰器案例及说明文档[图]

编程学习笔记之python深入之装饰器案例及说明文档[图] 装饰器即在不对一个函数体进行任何修改,以及不改变整体的原本意思的情况下,增加函数功能的新函数,因为这个新函数对旧函数进行了装饰...

原创小博客
昨天
1
0

没有更多内容

加载失败,请刷新页面

加载更多

下一页

返回顶部
顶部