文档章节

[DNN] 尝试理解深度神经网络的Large-batch魔咒

Airship
 Airship
发布于 2017/06/13 09:33
字数 1585
阅读 40
收藏 0

[DNN] 尝试理解深度神经网络的Large-batch魔咒

[DNN] 尝试理解深度神经网络的Large-batch魔咒

张泰源张泰源

1 天前

最近贵司的“一小时训练ImageNet”论文在国内外各种刷屏(https://research.fb.com/publications/imagenet1kin1h/),看了一下,确实非常实用主义的文章,介绍很多有用的trick,包括系统实现上的很多坑都覆盖到了。其中谈到加速训练的难点之一是:需要用到更大的mini-batch size,而这通常会降低准确率,所以他们通过linear-scaling learning rate解决了这个问题。看到这里我对于这个难点产生了疑问——batch size越大,不应该训练的方差越小,随机性越小,从而能够更准确地拟合数据集么?

从一个对深度学习接触不多的人(比如我)的角度,这点确实有点反直觉。当batch-size不断增大,直到跟数据集一样大的时候,SGD (Stochastic Gradient Descent)就变成了最朴素的GD,一次梯度更新会扫描一遍所有的数据来算梯度。看教科书和在CMU上Machine Learning的时候被灌输的理念都是:SGD相对于GD,或者小的batch相对于大的batch,有助于更快收敛,但是准确度会下降。为什么到了深度神经网络这里就反过来了呢?

我的第一猜想是:神经网络的函数空间非常non-convex,所以mini-batch越小就越容易不断跳出local minima,寻找更好的最小值。但是自己马上感觉这个猜想有很多漏洞,不能自圆其说,所以我去查证了一下其他人的分析——Facebook的论文原文有提及过这个问题,以及ICLR 2017上也有一篇论文针对这个问题分析了一下。有趣的事,两篇文章的观点并不相同,Facebook的论文还轻踩了对方一下说“根据我们跑的实验这事不是你们说的那样儿的”。由于ICLR 的论文先出来,我们先看看它怎么说:

ICLR 2017: ON LARGE-BATCH TRAINING FOR DEEP LEARNING: GENERALIZATION GAP AND SHARP MINIMA https://openreview.net/pdf?id=H1oyRlYgg 

这篇文章主要研究了“为什么Large batch size会让错误率提高”的问题,提出了四个可能的猜想:

(i) LB methods over-fit the model;

(ii) LB methods are attracted to saddle points;

(iii) LB methods lack the explorative properties of SB methods and tend to zoom-in on the minimizer closest to the initial point;

(iv) SB and LB methods converge to qualitatively different minimizers with differing generalization properties.

然后通过实验,得出了支持(iii)和(iv)的证据。也就是说,主要是两点原因:

1) LB (Large-Batch) 方法探索性太差,容易在离起始点附近很近的地方停下来

2) LB和SB由于训练方式上的差异,最终会导致它们最终收敛的点具有一些数学属性的差异

#1 很好理解,跟我前面的猜想有点类似。这里着重谈谈#2 - 文章谈到,LB方法会收敛到Sharp-minimum,而SB方法会收敛到Flat-minimum。这两种minimum的差别如图所示:

在同样的Bias下,明显Flat的曲线比Sharp的曲线更加接近真实情况,所以Flat Minimum的generalization performance更好。

然后,基于这个假设,他们给出的解决方案是:先用SB方法训练几个epoch,让它先探索一下,找到一个比较Flat的区域,再用LB方法慢慢收敛到正确的地方。论文给出了performance vs. # of epoch trained with SB,但个人感觉不是很有说服力。。。

Facebook: Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour

再回到Facebook这篇文章,作者认为,LB之所以不work,不是因为上面那篇论文提到的泛化能力的问题,而主要是一个optimization issue(我的理解是优化过程/优化算法的问题)。文章没有给出理论分析, 而是直接给出了实验数据:首先,这篇论文是基于“Linear Scaling Learning Rate”来做的,简单来说,假如说原来batch size是256,learning rate是0.1;那么当把batch size设成8192的时候,learning rate就设成3.2 。batch size翻多少倍,learning rate就翻多少倍。然后,基于这个方法,论文作者发现,如果用LB方法,刚开始就用很大的learning rate的话,效果其实是很差的;但是,只要刚开始把LR设小点,后来逐步把LR提高到正常的大小,那么效果拔群,LB能够得到跟SB几乎一毛一样的training curve,以及基本相同的准确度。

基于这个观察,作者认为,LB不work的主要原因是

large minibatch sizes are challenged by optimization difficulties in early training

(至于为什么,这个跟Linear Scaling Learning Rate的assumption有关:简单来说,就是Linear Scaling Learning Rate这个trick是基于一定的assumption的,而这个assumption在网络权重急剧变化的时候——也就是刚开始训练的时候——是不成立的。所以,一开始就应用那么大的learning rate会出事。我解释的不是很清楚,具体可以去看原论文)

总结

上篇两篇论文各有千秋:ICLR那篇着重理论分析,用漂亮的实验验证了Sharp-minimum和Flat-minimum的区别,启发性非常大,但是给出的解决方案不是很令人信服;Facebook这篇直接从实战经验出发,实验和解释都比较令人信服,不过理论上相对弱些。

对于两者的Claim,其实不能说谁对谁错,因为两者的实验方法不一样;ICLR那篇没有应用Linear Scaling Learning Rate而是直接应用了ADAM来作为optimizer,得出的结果跟Facebook的肯定不能直接相比。如果ICLR那篇论文的作者可以使用Facebook的方法论重新跑实验的话,说不定得出的结论会有很大不同。甚至说,双方的结论其实不完全互斥,而是可以被统一成一个理论(比如我现在拍脑袋想的:刚开始训练的时候,Large-batch得出来的梯度不准确,所以如果设的learning rate太大,就更加容易陷入Sharp-minimum出不来,从而影响到后面的优化,之类之类的)。

「真诚赞赏,手留余香」

本文转载自:https://zhuanlan.zhihu.com/p/27349632

共有 人打赏支持
Airship
粉丝 39
博文 932
码字总数 19883
作品 0
南京
高级程序员
私信 提问
Deep Learning方向的paper

个人阅读的Deep Learning方向的paper整理,分了几部分吧,但有些部分是有交叉或者内容重叠,也不必纠结于这属于DNN还是CNN之类,个人只是大致分了个类。目前只整理了部分,剩余部分还会持续更...

langb2014
2016/03/06
0
0
前沿 | 利用遗传算法优化神经网络:Uber提出深度学习训练新方式

  选自Uber   作者:Kenneth O. Stanley、Jeff Clune   机器之心编译   参与:陈韵竹、刘晓坤      在深度学习领域,对于具有上百万个连接的多层深度神经网络(DNN),现在往往通...

机器之心
2017/12/22
0
0
深度学习加速器Layer Normalization-LN

前面介绍了Batch Normalization(BN),公众号菜单栏可以获得文章链接,今天介绍一种和BN类似的深度学习normalize算法Layer Normalization(LN)。 LN提出:BN针对一个minibatch的输入样本,...

lqfarmer
2017/05/25
0
0
学界 | 深度神经网络为什么不易过拟合?傅里叶分析发现固有频谱偏差

  选自arXiv   作者:Naism Rahaman等   机器之心编译   参与:Geek AI、刘晓坤      过参数化的深度神经网络是一类表达能力极强的函数,甚至能 100% 记住随机数据。这向我们提出...

机器之心
2018/07/15
0
0
斯坦福完全可解释深度神经网络:你需要用决策树搞点事

  选自Stanford   机器之心编译   参与:路雪、黄小天、刘晓坤      近日,斯坦福大学计算机科学博士生 Mike Wu 发表博客介绍了他对深度神经网络可解释性的探索,主要提到了树正则...

机器之心
2018/01/10
0
0

没有更多内容

加载失败,请刷新页面

加载更多

二进制相关

二进制 众所周知计算机使用的是二进制,数字的二进制是如何表示的呢? 实际就是逢二进一。比如 2 用二进制就是 10。那么根据此可以推算出 5的二进制等于 10*10+1 即为 101。 在计算机中,负数以...

NotFound403
昨天
2
0
day22:

1、写一个getinterface.sh 脚本可以接受选项[i,I],完成下面任务: 1)使用格式:getinterface.sh [-i interface | -I ip] 2)当用户使用-i选项时,显示指定网卡的IP地址;当用户使用-I选项...

芬野de博客
昨天
2
0
Spring Cloud Alibaba基础教程:使用Nacos实现服务注册与发现

自Spring Cloud Alibaba发布第一个Release以来,就备受国内开发者的高度关注。虽然Spring Cloud Alibaba还没能纳入Spring Cloud的主版本管理中,但是凭借阿里中间件团队的背景,还是得到不少...

程序猿DD
昨天
4
0
Java并发编程:深入剖析ThreadLocal

ThreadLocal 的理解 ThreadLocal,很多地方叫线程本地变量,或线程本地存储。ThreadLocal为变量在每个线程中都创建了一个副本,每个线程可以访问自己内部的副本变量。===》解决的问题是线程间...

细节探索者
昨天
3
0
【Python3之异常处理】

一、错误和异常 1.错误 代码运行前的语法或者逻辑错误 语法错误(这种错误,根本过不了python解释器的语法检测,必须在程序执行前就改正) def test: ^SyntaxError: invalid...

dragon_tech
昨天
2
0

没有更多内容

加载失败,请刷新页面

加载更多

返回顶部
顶部