文档章节

[DNN] 尝试理解深度神经网络的Large-batch魔咒

Airship
 Airship
发布于 2017/06/13 09:33
字数 1585
阅读 29
收藏 0

[DNN] 尝试理解深度神经网络的Large-batch魔咒

[DNN] 尝试理解深度神经网络的Large-batch魔咒

张泰源张泰源

1 天前

最近贵司的“一小时训练ImageNet”论文在国内外各种刷屏(https://research.fb.com/publications/imagenet1kin1h/),看了一下,确实非常实用主义的文章,介绍很多有用的trick,包括系统实现上的很多坑都覆盖到了。其中谈到加速训练的难点之一是:需要用到更大的mini-batch size,而这通常会降低准确率,所以他们通过linear-scaling learning rate解决了这个问题。看到这里我对于这个难点产生了疑问——batch size越大,不应该训练的方差越小,随机性越小,从而能够更准确地拟合数据集么?

从一个对深度学习接触不多的人(比如我)的角度,这点确实有点反直觉。当batch-size不断增大,直到跟数据集一样大的时候,SGD (Stochastic Gradient Descent)就变成了最朴素的GD,一次梯度更新会扫描一遍所有的数据来算梯度。看教科书和在CMU上Machine Learning的时候被灌输的理念都是:SGD相对于GD,或者小的batch相对于大的batch,有助于更快收敛,但是准确度会下降。为什么到了深度神经网络这里就反过来了呢?

我的第一猜想是:神经网络的函数空间非常non-convex,所以mini-batch越小就越容易不断跳出local minima,寻找更好的最小值。但是自己马上感觉这个猜想有很多漏洞,不能自圆其说,所以我去查证了一下其他人的分析——Facebook的论文原文有提及过这个问题,以及ICLR 2017上也有一篇论文针对这个问题分析了一下。有趣的事,两篇文章的观点并不相同,Facebook的论文还轻踩了对方一下说“根据我们跑的实验这事不是你们说的那样儿的”。由于ICLR 的论文先出来,我们先看看它怎么说:

ICLR 2017: ON LARGE-BATCH TRAINING FOR DEEP LEARNING: GENERALIZATION GAP AND SHARP MINIMA https://openreview.net/pdf?id=H1oyRlYgg 

这篇文章主要研究了“为什么Large batch size会让错误率提高”的问题,提出了四个可能的猜想:

(i) LB methods over-fit the model;

(ii) LB methods are attracted to saddle points;

(iii) LB methods lack the explorative properties of SB methods and tend to zoom-in on the minimizer closest to the initial point;

(iv) SB and LB methods converge to qualitatively different minimizers with differing generalization properties.

然后通过实验,得出了支持(iii)和(iv)的证据。也就是说,主要是两点原因:

1) LB (Large-Batch) 方法探索性太差,容易在离起始点附近很近的地方停下来

2) LB和SB由于训练方式上的差异,最终会导致它们最终收敛的点具有一些数学属性的差异

#1 很好理解,跟我前面的猜想有点类似。这里着重谈谈#2 - 文章谈到,LB方法会收敛到Sharp-minimum,而SB方法会收敛到Flat-minimum。这两种minimum的差别如图所示:

在同样的Bias下,明显Flat的曲线比Sharp的曲线更加接近真实情况,所以Flat Minimum的generalization performance更好。

然后,基于这个假设,他们给出的解决方案是:先用SB方法训练几个epoch,让它先探索一下,找到一个比较Flat的区域,再用LB方法慢慢收敛到正确的地方。论文给出了performance vs. # of epoch trained with SB,但个人感觉不是很有说服力。。。

Facebook: Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour

再回到Facebook这篇文章,作者认为,LB之所以不work,不是因为上面那篇论文提到的泛化能力的问题,而主要是一个optimization issue(我的理解是优化过程/优化算法的问题)。文章没有给出理论分析, 而是直接给出了实验数据:首先,这篇论文是基于“Linear Scaling Learning Rate”来做的,简单来说,假如说原来batch size是256,learning rate是0.1;那么当把batch size设成8192的时候,learning rate就设成3.2 。batch size翻多少倍,learning rate就翻多少倍。然后,基于这个方法,论文作者发现,如果用LB方法,刚开始就用很大的learning rate的话,效果其实是很差的;但是,只要刚开始把LR设小点,后来逐步把LR提高到正常的大小,那么效果拔群,LB能够得到跟SB几乎一毛一样的training curve,以及基本相同的准确度。

基于这个观察,作者认为,LB不work的主要原因是

large minibatch sizes are challenged by optimization difficulties in early training

(至于为什么,这个跟Linear Scaling Learning Rate的assumption有关:简单来说,就是Linear Scaling Learning Rate这个trick是基于一定的assumption的,而这个assumption在网络权重急剧变化的时候——也就是刚开始训练的时候——是不成立的。所以,一开始就应用那么大的learning rate会出事。我解释的不是很清楚,具体可以去看原论文)

总结

上篇两篇论文各有千秋:ICLR那篇着重理论分析,用漂亮的实验验证了Sharp-minimum和Flat-minimum的区别,启发性非常大,但是给出的解决方案不是很令人信服;Facebook这篇直接从实战经验出发,实验和解释都比较令人信服,不过理论上相对弱些。

对于两者的Claim,其实不能说谁对谁错,因为两者的实验方法不一样;ICLR那篇没有应用Linear Scaling Learning Rate而是直接应用了ADAM来作为optimizer,得出的结果跟Facebook的肯定不能直接相比。如果ICLR那篇论文的作者可以使用Facebook的方法论重新跑实验的话,说不定得出的结论会有很大不同。甚至说,双方的结论其实不完全互斥,而是可以被统一成一个理论(比如我现在拍脑袋想的:刚开始训练的时候,Large-batch得出来的梯度不准确,所以如果设的learning rate太大,就更加容易陷入Sharp-minimum出不来,从而影响到后面的优化,之类之类的)。

「真诚赞赏,手留余香」

本文转载自:https://zhuanlan.zhihu.com/p/27349632

共有 人打赏支持
Airship
粉丝 38
博文 879
码字总数 18996
作品 0
南京
高级程序员
Deep Learning方向的paper

个人阅读的Deep Learning方向的paper整理,分了几部分吧,但有些部分是有交叉或者内容重叠,也不必纠结于这属于DNN还是CNN之类,个人只是大致分了个类。目前只整理了部分,剩余部分还会持续更...

langb2014
2016/03/06
0
0
学界 | 深度神经网络为什么不易过拟合?傅里叶分析发现固有频谱偏差

  选自arXiv   作者:Naism Rahaman等   机器之心编译   参与:Geek AI、刘晓坤      过参数化的深度神经网络是一类表达能力极强的函数,甚至能 100% 记住随机数据。这向我们提出...

机器之心
07/15
0
0
前沿 | 利用遗传算法优化神经网络:Uber提出深度学习训练新方式

  选自Uber   作者:Kenneth O. Stanley、Jeff Clune   机器之心编译   参与:陈韵竹、刘晓坤      在深度学习领域,对于具有上百万个连接的多层深度神经网络(DNN),现在往往通...

机器之心
2017/12/22
0
0
深度学习加速器Layer Normalization-LN

前面介绍了Batch Normalization(BN),公众号菜单栏可以获得文章链接,今天介绍一种和BN类似的深度学习normalize算法Layer Normalization(LN)。 LN提出:BN针对一个minibatch的输入样本,...

lqfarmer
2017/05/25
0
0
深度神经网络:攻陷语音识别最后堡垒的杀手锏?

【赛迪网讯】1月30日消息,尽管手机终端上各种语音助手的混战正如火如荼,但对于一些有着浓重口音的用户而言,语音助手的体验似乎远没有宣传的那么好:语音助手听不懂自己的话,这才是最大的...

修真0
2013/02/06
1K
4

没有更多内容

加载失败,请刷新页面

加载更多

Generator-ES6

基本概念 Generator 函数是 ES6 提供的一种异步编程解决方案,语法行为与传统函数完全不同。 Generator 函数有多种理解角度。语法上,首先可以把它理解成,Generator 函数是一个状态机,封装...

简心
25分钟前
4
0
FullCalendar日历插件说明文档

普通显示设置 属性 描述 默认值 header 设置日历头部信息。 如果设置为false,则不显示头部信息。包括left,center,right左中右三个位置,每个位置都可以对应以下不同的配置: title: 显示当...

ada_young
26分钟前
1
0
Redis知识总结--string的内部实现

SDS(Simple Dynamic String) String的数据结构是一个字节数组,但简单的获取数组长度的时间复杂度就是O(n),这对于单线程的redis来讲是不能接受的,因此string在redis中的实现是SDS类,SDS类...

looqy
36分钟前
2
0
SpringBoot开发案例之整合Dubbo分布式服务

前言 在 SpringBoot 很火热的时候,阿里巴巴的分布式框架 Dubbo 不知是处于什么考虑,在停更N年之后终于进行维护了。在之前的微服务中,使用的是当当维护的版本 Dubbox,整合方式也是使用的 ...

Java干货分享
42分钟前
5
0
美团团购订单系统优化记

团购订单系统简介 美团团购订单系统主要作用是支撑美团的团购业务,为上亿美团用户购买、消费提供服务保障。2015年初时,日订单量约400万~500万,同年七夕订单量达到800万。 目标 作为线上S...

Skqing
46分钟前
1
0

没有更多内容

加载失败,请刷新页面

加载更多

返回顶部
顶部