

由斯坦福大学提出的FlashAttention方法,让使用更长sequence计算Attention成为可能,并且通过线性级别的增长来节省内存以及加速计算。因为FlashAttention没有进行近似计算,所以也没有精度损失。然而,FlashAttention的实际速度仍然和理论上的运算速度差距较大,仅达到理论最大 FLOPs/s 的 25-40%。效率低下的原因主要是不同线程块和warp之间的工作分区不理想,导致低占用率或不必要的共享内存读/写。为此,2023年7月,论文作者进一步提出了FlashAttention-2,实现了Attention计算速度的大幅度提升。

▐ 主要内容
▐ 主要操作

背景知识:
上图的左图,表示存储结构,可以简单理解为:SRAM表示缓存,HBM表示显存,DRAM表示内存。
tiling
在不访问整个输入的情况下优化attention计算,并减少相关计算量。重构attention计算,将输入分割成块,并对分块进行多次传递,从而逐步执行attention计算(该步骤称为tiling)。
如上图所示,FlashAttention 使用tiling来防止在相对较慢的 GPU显存上实现大型 𝑁 × 𝑁 注意力矩阵(虚线框)计算。在外部循环(红色箭头)中,FlashAttention 循环遍历 K 和 V 矩阵块,并将它们加载到快速片上 SRAM。在每个块中,FlashAttention 循环遍历 Q 矩阵块(蓝色箭头),将它们加载到 SRAM,并将注意力计算的输出写回 HBM。
将输入Q、K、V矩阵分成很多块,将它们从较慢的HBM加载到较快的SRAM,然后在SRAM计算关于这些块的注意力输出。对每个块的计算结果缩放之后进行add操作,则得到正确的结果,具体伪代码如图:
recomputing

▐ Block-Sparse FlashAttention

▐ 小结
总的来说,FlashAttention有如下优点:
hbm访问次数降低,所以计算更快
在sram中计算attention,并对于后向计算提前保留中间结果,所以显存占用更少
可以使用更长的sequence,使得模型训练效果更好
对于attention计算,加速明显。如果加上稀疏化处理,速度会更快。

▐ 主要内容
▐ 主要操作
减少非矩阵运算
背景知识:
吞吐量是指单位时间内完成的任务数量或数据处理量。在这个上下文中,吞吐量指的是执行矩阵乘法操作时的性能表现,以及执行其他非矩阵乘法操作时的性能表现。这句话的意思是,执行矩阵乘法操作时,系统能够以每单位时间处理更多的任务或数据,其数量可以高达非矩阵乘法操作时的16倍。这表明矩阵乘法操作在性能上比其他操作更加高效。



增加并行比例

Forward pass:对批量维度和头数维度进行并行化,如 FlashAttention 中所做的那样。对于外循环(在序列长度上),将它们调度到不需要彼此通信的不同线程块上,每个工作线程负责关注矩阵的一行block块。外循环每次处理一行block,内循环每次处理这一行中的一列block,这和FlashAttention处理方式是不同的。
Backward pass:不同列块之间唯一共享的计算是算法 2 中更新的dQ,其中我们需要将 dQ从 HBM 加载到 SRAM,然后在片上通过 dQ更新,并写回 HBM。我们使用原子添加在不同线程块之间进行通信以更新 dQ。我们也在序列长度维度上进行并行化,并为后向传递的每一列block块安排 1 个工作线程(和前向传递是反过来的)。
在warp上优化工作划分
在一个注意力计算的block内,在一个thread block的不同warp之间优化工作划分,以减少通信和共享内存的读/写。
在每个线程块内,我们也必须决定如何在不同的 warp 之间划分工作。我们通常每个线程块使用 4 或 8 个 warp,分区如上图所示。
Forward pass:对于每个块,FlashAttention 将 K 和 V 分割到 4 个 warp 上,同时保持 Q 可被所有 warp 访问。每个warp相乘得到 QK⊤ 的slice,然后它们需要与 V 的slice相乘并进行通信以将结果相加。这称为“split-K”方案。然而,这是低效的,因为所有 warp 都需要将其中间结果写入共享内存,进行同步,然后将中间结果相加。这些共享内存读/写会减慢 FlashAttention 中的前向传播速度。在 FlashAttention-2 中,我们将 Q 分成 4 个经线,同时保持所有经线均可访问 K 和 V。在每个扭曲执行矩阵乘法以获得 QK⊤ 切片后,它们只需与共享的 V 切片相乘即可获得相应的输出切片。warp 之间不需要通信。共享内存读/写的减少可以提高速度。
背景知识:
warp:由多个thread组成,是编程层面的概念。
flash1:k和v被分为4个不同的warp,q和k计算、再和v计算,每一次计算的中间结果都要写入共享内存,并在之后被读取。这样就增加了共享内存的读写次数、拖慢了速度。
flash2:将q分为4个不同的warp,然后计算qk、计算v。但是这里k和v不需要通信,所以计算v的时候,不需要新的内存读写。这样就减少了读写次数、加快了程序。
▐ 小结
FlashAttention-2加速实践
▐ 时间与显存的优化效果
对于qkv计算,比较FlashAttention2与custom pytorch、xformers(FlashAttention1)的时间与显存消耗。如果只考虑QKV计算,flash attention2耗时是xformers(flash attention1)的一半,内存节省也更多一些。
flash attention2耗时是xformers(flash attention1)的一半,内存节省也更多一些
test 0 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (9, 351, 16, 64), dtype: torch.bfloat16, causal: True
custom pytorch time: 0.000754, peak memory: 113 MB
flash attention time: 0.000103, speedup: 7.29; peak memory: 45 MB, save: 60%
xformers time: 0.000255, speedup: 2.95; peak memory: 63 MB, save: 44%
test 1 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (9, 351, 16, 64), dtype: torch.bfloat16, causal: True
custom pytorch time: 0.000703, peak memory: 131 MB
flash attention time: 0.000106, speedup: 6.63; peak memory: 57 MB, save: 56%
xformers time: 0.000252, speedup: 2.80; peak memory: 70 MB, save: 46%
test 2 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (9, 351, 16, 64), dtype: torch.bfloat16, causal: True
custom pytorch time: 0.000721, peak memory: 131 MB
flash attention time: 0.000106, speedup: 6.78; peak memory: 57 MB, save: 56%
xformers time: 0.000263, speedup: 2.74; peak memory: 70 MB, save: 46%
test 3 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (9, 351, 16, 64), dtype: torch.bfloat16, causal: True
custom pytorch time: 0.000704, peak memory: 131 MB
flash attention time: 0.000105, speedup: 6.71; peak memory: 57 MB, save: 56%
xformers time: 0.000249, speedup: 2.82; peak memory: 70 MB, save: 46%
test 4 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (9, 351, 16, 64), dtype: torch.bfloat16, causal: True
custom pytorch time: 0.000700, peak memory: 131 MB
flash attention time: 0.000110, speedup: 6.35; peak memory: 57 MB, save: 56%
xformers time: 0.000254, speedup: 2.75; peak memory: 70 MB, save: 46%
test 5 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (9, 351, 16, 64), dtype: torch.bfloat16, causal: True
custom pytorch time: 0.000766, peak memory: 131 MB
flash attention time: 0.000106, speedup: 7.25; peak memory: 57 MB, save: 56%
xformers time: 0.000252, speedup: 3.04; peak memory: 70 MB, save: 46%
test 6 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (9, 351, 16, 64), dtype: torch.bfloat16, causal: True
custom pytorch time: 0.000684, peak memory: 131 MB
flash attention time: 0.000101, speedup: 6.77; peak memory: 57 MB, save: 56%
xformers time: 0.000268, speedup: 2.56; peak memory: 70 MB, save: 46%
test 7 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (9, 351, 16, 64), dtype: torch.bfloat16, causal: True
custom pytorch time: 0.000717, peak memory: 131 MB
flash attention time: 0.000110, speedup: 6.52; peak memory: 57 MB, save: 56%
xformers time: 0.000254, speedup: 2.82; peak memory: 70 MB, save: 46%
test 8 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (9, 351, 16, 64), dtype: torch.bfloat16, causal: True
custom pytorch time: 0.000700, peak memory: 131 MB
flash attention time: 0.000100, speedup: 6.98; peak memory: 57 MB, save: 56%
xformers time: 0.000253, speedup: 2.77; peak memory: 70 MB, save: 46%
test 8 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (9, 351, 16, 64), dtype: torch.bfloat16, causal: True
custom pytorch time: 0.000700, peak memory: 131 MB
flash attention time: 0.000100, speedup: 6.98; peak memory: 57 MB, save: 56%
xformers time: 0.000253, speedup: 2.77; peak memory: 70 MB, save: 46%
test 9 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (9, 351, 16, 64), dtype: torch.bfloat16, causal: True
custom pytorch time: 0.000721, peak memory: 131 MB
flash attention time: 0.000102, speedup: 7.10; peak memory: 57 MB, save: 56%
xformers time: 0.000251, speedup: 2.87; peak memory: 70 MB, save: 46%
▐ 精度损失比较
计算FlashAttention2对于注意力机制的精度损失,与pytorch的计算精度进行对比。
绝大部分用例都可以通过测试,并且符合要求:
dQ Pytorch mean diff: 0.000698089599609375
dK Pytorch mean diff: 0.0005950927734375
dV Pytorch mean diff: 0.000537872314453125
.Actual dropout fraction: 0.17163611948490143
Output max diff: 0.001953125
Output mean diff: 2.9206275939941406e-05
Pytorch max diff: 0.0029296875
Pytorch mean diff: 8.106231689453125e-05
Attention max diff: 0.000244140625
Attention Pytorch max diff: 0.000732421875
dQ max diff: 0.0025577545166015625
dK max diff: 0.00390625
dV max diff: 0.0078125
dQ mean diff: 3.904104232788086e-05
dK mean diff: 0.0001360177993774414
dV mean diff: 0.0001475811004638672
dQ Pytorch max diff: 0.00390625
dK Pytorch max diff: 0.004150390625
dV Pytorch max diff: 0.0078125
dQ Pytorch mean diff: 8.702278137207031e-05
dK Pytorch mean diff: 0.00025916099548339844
dV Pytorch mean diff: 0.0002474784851074219
.Actual dropout fraction: 0.17163611948490143
Output max diff: 0.015625
Output mean diff: 0.0002346038818359375
Pytorch max diff: 0.015625
Pytorch mean diff: 0.00064849853515625
Attention max diff: 0.001953125
Attention Pytorch max diff: 0.00390625
dQ max diff: 0.01953125
dK max diff: 0.033203125
dV max diff: 0.0625
dQ mean diff: 0.0003108978271484375
dK mean diff: 0.00109100341796875
dV mean diff: 0.0011749267578125
dQ Pytorch max diff: 0.01806640625
dK Pytorch max diff: 0.0390625
dV Pytorch max diff: 0.0625
dQ Pytorch mean diff: 0.00069427490234375
dK Pytorch mean diff: 0.0020751953125
dV Pytorch mean diff: 0.001953125
...
FAILED tests/test_flash_attn.py::test_flash_attn_race_condition[0.0-128-128-False-dtype0] - assert False
FAILED tests/test_flash_attn.py::test_flash_attn_race_condition[0.0-128-128-True-dtype0] - assert False
FAILED tests/test_flash_attn.py::test_flash_attn_bwd_transpose[128-128-False-dtype0] - AssertionError: assert 236.75 <= (2 * 0.0009765625)
FAILED tests/test_flash_attn.py::test_flash_attn_bwd_transpose[128-128-False-dtype1] - AssertionError: assert 22144.0 <= (2 * 0.0078125)
FAILED tests/test_flash_attn.py::test_flash_attn_bwd_transpose[128-128-True-dtype0] - AssertionError: assert 2.724609375 <= (2 * 0.001953125)
FAILED tests/test_flash_attn.py::test_flash_attn_bwd_transpose[128-128-True-dtype1] - AssertionError: assert 95.5 <= (2 * 0.015625)

▐ 环境信息
NVIDIA A10, CUDA Version: 11.4, webui-1.5.1, eas推理平台
▐ 加速效果
xformers(flash1):
文生图(512*512)(batchsize=1) | 文生图(512*512)(batchsize=4) | |
unet耗时(s) | 1 1 1 1 |
4 4 4 4 |
unet耗时(it/s) (step = 20) |
11.11it/s 11.27it/s 11.27it/s 11.27it/s |
4.33it/s 4.33it/s 4.33it/s 4.33it/s |
文生图(512*512) | 文生图(512*512)(batchsize=4) | |
unet耗时(s) | 1 1 1 1 |
4 4 4 4 |
unet耗时(it/s) (step = 20) |
11.13it/s 11.75it/s 11.46it/s 11.92it/s |
4.69it/s 4.69it/s 4.69it/s 4.68it/s |
相对于xformers(flash1),xformers(flash2)提速:
unet过程提速 | |
文生图加速(一次生成1图) | (11.57-11.23)/11.23=3% |
文生图加速(一次生成4图) | (4.69-4.33)/4.33=8.3% |
▐ 精度比较
xformers(flash1)
文生图(512*512)_ouput1 | 文生图(512*512)_ouput2 |
![]() |
![]() |
xformers(flash2)
文生图(512*512)_ouput1 | 文生图(512*512)_ouput2 |
![]() |
![]() |
使用不同的加速方法,AIGC生成图像,均符合预期,无精度损失。
注:这里未固定seed,所以图像会有变化,但是生成效果符合预期。
▐ AIGC加速分析
SD模型自身特点
flash_attention2主要是针对qkv计算进行加速,sd的推理过程中还有很多别的计算。推理过程中,进行采样(去噪),具有大量的计算,qkv计算只是推理计算的一部分。对于大图,计算量也更大,qkv的计算比例也更大,所以可以得到更多的加速效果。
SD模型的网络结构:
SD社区代码特点
显卡性能特殊性

▐ 实验环境
▐ 加速效果
文生图(512*512) | 文生图(512*512)(batchsize=4) | |
unet耗时(s) |
|
3 3 3 3 |
unet耗时(it/s) (step = 20) |
17.06it/s 18.22it/s 17.36it/s 16.43it/s |
6.26it/s 6.27it/s 6.25it/s 6.25it/s |
unet过程提速 | |
文生图加速(一次生成1图) | (17.26-11.23)/11.23=54% |
文生图加速(一次生成4图) | (6.26-4.33)/4.33=45% |
使用xformers(flash2)+fastunet加速方法,AIGC生成图像,结果符合预期,无精度损失。
文生图(512*512)_ouput1 | 文生图(512*512)_ouput2 |
![]() |
![]() |
生图过程主要有两部分耗时:controlnet与unet
旧方法:xformers 0.0.20,使用flash attention1加速sd(unet+controlnet)
新方法:1.当前的fastunet只加速unet里的attention(换为flash attention2)。2.xformers0.0.21加速包括controlnet在内的所有attention(换为flash attention2)。3.fastunett还对其他算子也做了一些fuse操作,也起到了加速效果。
fastunet和xformers0.0.21加速的底层逻辑,都是使用flash attention2优化attention。fastunet和xformers0.0.21叠加使用,可以最大程度起到加速效果。新的加速方法主要针对attention计算进行优化,所以在unet及其attention部分会有更高比例的加速。

近年来,让 Transformers 能够处理更长的序列长度一直备受关注。这一发展有助于提升语言建模和高分辨率图像理解的能力,并为音频和视频生成等新的应用场景带来了机遇。FlashAttention方法使得使用更长的序列计算注意力成为可能,并通过线性级别的增长来节省内存并加速计算。这一方法为处理长序列的Transformer模型提供了一种有效的解决方案。最新提出的FlashAttention-2,也进一步实现了attention计算速度的大幅度提升。
当我们一直在关注GPU显存大小以及计算能力的时候,FlashAttention关注了GPU显存以外的SRAM,从而优化attention计算。也为我们解决问题提供了思考,即在主流关注的技术点以外,还有一些被忽视的但依旧可以解决问题的思路。面对实际效果与理论效果的差距,FlashAttention-2则进一步找到gap原因,通过关注矩阵运算、序列并行、工作分区等问题,优化计算效果。这也提醒我们,对于性能问题的解决,从软硬件结合的角度出发,才能更充分的解决问题。
在AIGC领域的生图任务中,使用diffusion model进行相关计算,需要大量时间完成生图过程。所以,通过FlashAttention-2等多种加速方法进一步提升AIGC的生图效率,具有深刻意义。我们团队致力于家装行业AIGC进行相关研发,以提高家装AI模型的效果。我们希望与对此方向感兴趣的同学一起探讨和交流。
团队介绍
本文分享自微信公众号 - 大淘宝技术(AlibabaMTT)。
如有侵权,请联系 support@oschina.cn 删除。
本文参与“OSC源创计划”,欢迎正在阅读的你也加入,一起分享。