FlashAttention-4正式发布:算法流水线大改,矩阵乘法级速度

FlashAttention-4正式发布:算法流水线大改,矩阵乘法级速度

 

文章摘要


【关 键 词】 深度学习注意力GPU大模型性能提升

经过一年开发,深度学习底层优化技术FlashAttention推出大版本更新FlashAttention-4,专门适配新一代Blackwell架构GPU在Blackwell GPU上,注意力机制的执行速度现在几乎与矩阵乘法一样快

此前FlashAttention-3主要针对Hopper H100架构优化,当前AI行业已转向部署Blackwell架构系统,该架构延续了硬件非对称扩展趋势:张量核心吞吐量增长速度远快于共享内存带宽、特殊函数单元等其他硬件资源,从Hopper H100到Blackwell B200,BF16张量核心吞吐量提升2.25倍,但特殊函数单元数量和共享内存带宽基本保持不变。这种不对称性打破了注意力性能完全由矩阵乘法速度决定的传统认知,实际分析显示,当前注意力前向传播的瓶颈是用于Softmax指数运算的特殊函数单元,反向传播瓶颈则是共享内存带宽。

FlashAttention-4采用算法与内核协同设计针对性解决瓶颈:为前向、反向传播分别设计新软件流水线,最大化张量核心计算、softmax计算与内存操作的重叠执行;前向传播通过多项式近似在FMA单元仿真指数函数提升吞吐量,引入条件式softmax重缩放,可避免90%的softmax重新缩放,缓解特殊函数单元瓶颈;反向传播利用张量内存存储中间结果,结合Blackwell新增的2-CTA MMA模式,将原子归约次数减少一半,降低共享内存访问压力,同时支持确定性执行保障训练可复现,新增的tile调度器解决了因果掩码和变长序列带来的负载不均衡问题。在B200(BF16)上,FlashAttention-4最高可达1605TFLOPs/s,实现71%的张量核心利用率,比cuDNN 9.13快1.3倍,比Triton快2.7倍,测试显示前向传播比cuDNN 9.13快1.1-1.3倍,比Triton实现快2.1-2.7倍,长序列场景下反向传播性能始终优于其他对比方案。

FlashAttention-4完全使用CuTe-DSL实现,相比C++模板编译,编译时间缩短20-30倍,安装编译仅需数秒。目前PyTorch官方已为FlexAttention添加FlashAttention-4后端,可自动生成修改代码并通过JIT编译为自定义注意力变体实例化FlashAttention-4,在算力受限工作负载下,相比Triton仍可实现1.2倍到3.2倍的性能提升,让研究人员无需在灵活性和高性能之间做取舍。该更新被认为是里程碑成果,相比FlashAttention-3性能提升2-3倍,将惠及所有前沿大模型,可带来更长有效上下文窗口、更低推理成本与更强规模化推理能力。(全文约760字)

原文和模型


【原文链接】 阅读原文 [ 1963字 | 8分钟 ]
【原文作者】 机器之心
【摘要模型】 doubao-seed-2-0-lite-260215
【摘要评分】 ★★★☆☆

© 版权声明
“绘蛙”

相关文章

“讯飞星辰”

暂无评论

暂无评论...