计算卷积神经网络浮点数运算量

原创
2018/09/24 18:06
阅读数 1.6W

前言

本文主要是介绍了,给定一个卷积神经网络的配置之后,如何大概估算它的浮点数运算量。

基于MXNet实现的计算网络模型运算量的小工具:

Python版本

Scala版本

在知乎上写了一篇博客算是对本文的内容的一些补充:https://zhuanlan.zhihu.com/p/65248401

正文

对于炼丹师来说,针对任务调整网络结构或者在做模型精简的时候,尤其是对模型的速度有要求的时候,

都想知道新模型的运算量有多大,虽然这个只是一个间接参考值,网络真正的运行速度还要考虑其他的

因素(具体解释可以参考shufflenet v2这篇文章)。那么对于给定一个卷积神经网络的模型定义,

该如何估算其浮点数运算量,对于卷积神经网络来说,卷积层的运算量是占网络总运算量的大头,而对于

一些像素级别任务,反卷积层也要算上,而全连接的权值大小是占网络权值的大头,运算量也有些。

所以一般来说把这三个层考虑上了,就能大概估算一个网络的运算量了。当然激活层(一般指relu),

BatchNorm层,比如残差网络会有elementwise sum,池化层,这些也会占一定的运算量。不过其实对于

BN来说,一般标配是conv + bn + relu,在上线使用过程中,可以把 bn 的权值融合进卷积层的权值中,

所以相当于没了bn这一层,变成 conv +relu,所以bn其实不用考虑,当然并不是所有的网络都是这样配置的。

网络各层运算量计算方法

卷积层运算量

对于卷积层来说,计算运算量的话其实很简单,因为卷积层的操作其实可以改写为矩阵乘法,这个思想很

经典了,把输入的feature map通过im2col操作生成一个矩阵,然后就可以和权值矩阵做乘法就得到了

输出的feature map,具体见下图:

画的有点难看,Cin是输入feature map 通道数,Hin和Win是输入feature map空间大小,

同样的Cout,Hout,Wout 对应输出feature map,然后 k表示卷积核空间大小。

首先最左边的权值矩阵很好理解,然后中间的矩阵就是输入feature map通过 im2col操作,

生成的矩阵,而输出矩阵的每个位置对应一个卷积核和输入的一个局域做一次卷积操作,

一个卷积核的大小就是  Cin * k * k,输入的一个区域要算上输入通道数,所以就对应了权值

矩阵的一行乘以输入矩阵的一列。

这样,卷积层的运算量就很明显了:   conv flops = Cout * Cin * k * k * Hout * Wout * BatchSize

如果还有偏置项的话,还要加上  BatchSize * Cout * Hout * Wout .

当然上面的公式没有考虑分组卷积的情况,但是demo的代码里面考虑了。

反卷积层运算量

反卷积其实也叫做转置卷积,其正反向传播,和卷积的正反向刚好相反,其运算量还是用画个图好解释。

这里要注意和上面卷积的符号定义区分开,这里的 Cin指的是反卷积层输入 feature map的

通道数,Hin 和 Win是输入的空间大小。所以Cout,Hout和Wout是反卷积层的输出大小。

而权值的形状为何是Cin * Cout * k * k,看图就很清晰了,首先我们知道反卷积的bp就是

卷积的fp,那么先从反卷积bp的角度来看,就相当于卷积的fp,中间的矩阵乘法就很好理解了。

然后反卷积的fp相当于把权值矩阵转置放到右边,就得到了反卷积的输出,然后这个输出并不是

最后的输出,还要通过col2im操作,把这个矩阵的值,回填到 Cout * Hout * Wout 这个矩阵里。

这样,反卷积层的运算量就很明显了:   deconv flops = Cin * Cout * k * k * Hin * Win * BatchSize

如果还有偏置项的话,还要加上  BatchSize * Cout * Hout * Wout .

所以计算反卷积的运算量,除了权值大小,输出大小(计算偏置),还需要知道输入的大小。

全连接层运算量

对与全连接层,即使矩阵向量乘法,其运算量就等于权值矩阵的大小,

所以 fullyconnected flops = BatchSize * Cout * Cin

Cout为全连接输出向量维度,Cin为输入维度。

如果有偏置项,还要加上: BatchSize * Cout。

池化层运算量

池化层的话就相当于卷积的简化版,这里根据池化的参数配置又可以分为两种情况,

如果是全局池化:

那么 pooling flops = BatchSize * Cin * Hin * Win 

如果是一般的池化:

那么从输出的角度来考虑,输出的feature map上每一个通道上的每一个点就对应着

输入feature map同样通道的上的一个k * k 区域的 max ,sum或者avg池化操作,

所以  pooling flops = BatchSize * Cout * Hout * Wout * k * k 

 

 

展开阅读全文
加载中
点击引领话题📣 发布并加入讨论🔥
打赏
0 评论
0 收藏
0
分享
返回顶部
顶部