斯坦福研究院推出FlashAttention-2:为长上下文语言模型带来速度和效率的飞跃

Stanford Research Institute launches FlashAttention-2 a leap in speed and efficiency for long-context language models.

在过去的一年中,自然语言处理领域取得了显著的进展,因为出现了具有更长上下文的语言模型。其中包括具有32k上下文长度的GPT-4,具有65k上下文长度的MosaicML的MPT,以及具有100k上下文长度的Anthropic的Claude。随着长文档查询和故事创作等应用的不断增长,对具有扩展上下文的语言模型的需求变得明显。然而,挑战在于扩展Transformer的上下文长度,因为它们的注意力层的计算和内存需求随着输入序列长度的增长而呈二次方增长。

应对这一挑战,一年前发布的创新算法FlashAttention在各个组织和研究实验室中迅速得到了采用。该算法成功加速了注意力计算,同时减少了其内存占用,而不会牺牲准确性或近似结果。在初始发布时,FlashAttention比优化基准线快2-4倍,证明了其创新性的突破。然而,它仍然有未开发的潜力,因为它没有达到在A100 GPU上可达124 TFLOPs/s的高速优化矩阵乘法(GEMM)操作。

为了迈出下一个飞跃,FlashAttention的开发人员现已推出FlashAttention-2,这是一个显著超越其前身的版本。FlashAttention-2利用了Nvidia的CUTLASS 3.x和CuTe核心库,实现了惊人的2倍加速,在A100 GPU上达到230 TFLOPs/s的速度。此外,在GPT风格语言模型的端到端训练中,FlashAttention-2实现了高达225 TFLOPs/s的训练速度,模型FLOP利用率达到72%。

FlashAttention-2的关键改进在于其更好的并行性和工作分区策略。最初,FlashAttention通过批大小和头数并行化,有效利用了GPU上的计算资源。然而,对于较小批次大小或较少头数的长序列,FlashAttention-2现在通过序列长度维度进行并行化,从而在这些情况下实现了显著的加速。

另一个改进涉及在每个线程块内有效地分区不同的线程束之间的工作。在FlashAttention中,将K和V分成四个线程束,同时让Q可被所有线程束访问,称为“切片-K”方案,导致了不必要的共享内存读写,减慢了计算速度。FlashAttention-2采用了不同的方法,现在将Q分成四个线程束,同时让K和V可被所有线程束访问。这消除了线程束之间的通信需求,显著减少了共享内存的读写,进一步提高了性能。

FlashAttention-2引入了几个新功能,扩大了其适用性并增强了其功能。它现在支持高达256个头维度,适用于像GPT-J、CodeGen、CodeGen2和StableDiffusion 1.x这样的模型,为更多的加速和节省内存的机会打开了大门。此外,FlashAttention-2还采用了多查询注意力(MQA)和分组查询注意力(GQA)变种,其中查询的多个头可以参与到相同的键和值的头中,从而提高了推理吞吐量和性能。

FlashAttention-2的性能真是令人印象深刻。在A100 80GB SXM4 GPU上进行基准测试,与其前身相比,它实现了大约2倍的加速,并且与PyTorch中的标准注意力实现相比,速度提高了多达9倍。此外,当用于GPT风格模型的端到端训练时,FlashAttention-2在A100 GPU上可以达到225 TFLOPs/s的速度,相比已经高度优化的模型,实现了1.3倍的端到端加速。

展望未来,FlashAttention-2的潜在应用前景令人期待。使用相同价格训练具有16k更长上下文的模型,相较于之前的8k上下文模型,这项技术可以帮助分析长篇书籍、报告、高分辨率图像、音频和视频。在H100 GPU和AMD GPU等设备上实现更广泛的适用性和对新数据类型(如fp8)的优化计划正在进行中。此外,将FlashAttention-2的低级优化与高级算法变更相结合,可以为训练具有前所未有的更长上下文的AI模型铺平道路。与编译器研究人员合作以提高可编程性也在不远的将来,为下一代语言模型带来光明的未来。