高效 AI · Transformer

FlashAttention:从 GPU 内存读写里挤出的注意力加速

FlashAttention 保持注意力计算精确,但让算法具备 IO awareness,通过 tiling 减少慢速 GPU 内存访问,让长序列 Transformer 更快、更省显存。

一句话

FlashAttention 保持注意力计算精确,但让算法具备 IO awareness,通过 tiling 减少慢速 GPU 内存访问,让长序列 Transformer 更快、更省显存。

解决什么问题

self-attention 在长序列上很贵,因为时间和显存都随序列长度平方增长。很多替代方法选择近似注意力,但近似可能伤害模型质量,也不一定带来真实 wall-clock 加速。FlashAttention 的判断是:瓶颈不只是算术量,还在于 GPU 不同层级内存之间的数据搬运。

核心方法

FlashAttention 用 IO-aware 的 tiling 算法计算精确注意力。它避免在高带宽显存里物化完整 attention 矩阵,而是把数据块放到更快的片上 SRAM 中处理,并谨慎选择哪些内容重新计算、哪些内容存储。数学结果保持精确,改变的是内存调度方式。

关键结果

论文报告 FlashAttention 相比已有基线能更快训练 Transformer,在 GPT-2 和长程任务上都有明显加速。它还支持更长上下文并带来更好模型质量,包括在很长序列的 Path-X 类挑战上取得超过随机的表现。它的影响力来自一个现实点:优化真正贴合 GPU 硬件,而不只是在复杂度表格上好看。

为什么重要

FlashAttention 后来变成基础设施。它让注意力密集模型更快、显存压力更低,也让长上下文系统更实际。很多用户不会直接接触这篇论文,但现代训练和推理栈底层往往已经静默包含了类似 FlashAttention 的 kernel。

局限与存疑

FlashAttention 加速了注意力,但不像全新架构那样改变注意力本身的平方依赖。它也依赖精细 kernel 工程和具体硬件假设。更持久的启发是:在大模型时代,算法设计和系统设计已经不能分开看。

一句话:FlashAttention 把内存搬运当成注意力的真正瓶颈来优化。