Attention is a core component of the transformer architecture used in large language models (LLMs). But as LLMs grow larger and handle longer input sequences, the computational cost of attention becomes a bottleneck.
To address this challenge, researchers from Colfax Research, Meta, Nvidia, Georgia Tech, Princeton University, and Together AI have introduced FlashAttention-3, a new technique that significantly speeds up attention computation on Nvidia Hopper GPUs (H100 and H800).
FlashAttention-3 builds upon previous work on FlashAttention and FlashAttention-2 and further optimizes the use of resources on Nvidia Hopper GPUs to maximize performance and efficiency for LLM training and inference.
The challenge of attention computation in LLMs
One of the key innovations of transformers is the attention mechanism, which enables the model to compute the relationship between different tokens in an input sequence.
While the attention mechanism is very effective, it is also computationally expensive. The cost of attention computation grows quadratically with the length of the input sequence. As LLMs are scaled to handle longer and longer input sequences, the attention mechanism becomes a major bottleneck.
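To make the quadratic growth concrete, here is a minimal sketch that counts the query-key scores one attention head computes for a sequence. The function name is hypothetical and chosen only for illustration.

```python
def attention_score_entries(seq_len: int) -> int:
    """Count the query-key scores for one attention head.

    Every token attends to every token, so there is one score per token pair.
    """
    return seq_len * seq_len


# Doubling the sequence length quadruples the number of scores.
for n in (1_000, 2_000, 4_000):
    print(n, attention_score_entries(n))
```

Doubling the input from 1,000 to 2,000 tokens multiplies the number of scores by four, which is why long contexts strain attention far more than the rest of the model.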
Furthermore, modern hardware accelerators such as GPUs are optimized for matrix multiplication (matmul) operations, which are the building blocks of deep learning models. These accelerators also have computational units for other types of operations such as exponentiation, but those units deliver hundreds of times less throughput than the matmul components.
Attention computations use a combination of matrix multiplications and other special functions that are not as optimized for GPUs.
For example, the softmax function, which is used to normalize the attention weights, relies on exponentiation, which GPUs execute at far lower throughput than matrix multiplication. As a result, even though matrix multiplications account for most of the computations in attention, the overall computation can get bogged down by a small number of special functions.
An important part of optimizing attention computation is therefore scheduling the workload so that operations do not block one another and the different types of GPU memory are used efficiently.
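A rough back-of-the-envelope sketch shows how a small number of slow operations can take a large share of the time. It assumes that each attention score costs about 4 × head_dim matmul FLOPs (two matrix products) plus one exponential, that nothing overlaps, and that the matmul units are `throughput_ratio` times faster. The head dimension of 128 and the ratio of 256 are illustrative assumptions, not figures from the article, and the other softmax steps (max, sum, rescaling) are ignored.

```python
def exp_time_fraction(head_dim: int, throughput_ratio: float) -> float:
    """Estimate the share of attention time spent on exponentials.

    Toy model: times are measured in units of "one exponential", and the
    matmul units are assumed to be `throughput_ratio` times faster per op.
    """
    # Two matrix products per score (scores, then weighted values), 2 * head_dim FLOPs each.
    matmul_flops_per_score = 4 * head_dim
    matmul_time = matmul_flops_per_score / throughput_ratio
    exp_time = 1.0  # one exponential per score
    return exp_time / (matmul_time + exp_time)


# Hypothetical numbers: head dimension 128, matmul units 256x faster.
print(exp_time_fraction(head_dim=128, throughput_ratio=256))
```

Under these assumptions the exponentials are less than 0.2% of the operations but take about a third of the run time if they are not overlapped with the matrix multiplications.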
Making better use of hardware resources
FlashAttention, introduced in 2022, addressed the challenges of computing attention by reducing the number of memory reads and writes between GPU high bandwidth memory (HBM) and GPU on-chip static random access memory (SRAM) when doing attention computation. Instead of computing the attention weights for the entire sequence at once, FlashAttention breaks down the computation into smaller chunks, called “tiles,” that can be processed more efficiently on GPUs.
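The tiling idea can be sketched in plain Python for a single query vector. This is a simplified illustration, not the FlashAttention GPU kernel: it walks over the keys and values one tile at a time and keeps only a running maximum, a running sum, and a running weighted output, so the full row of attention weights is never stored. All function names and the sample data are hypothetical.

```python
import math


def dot(a, b):
    """Dot product of two equal-length lists."""
    return sum(x * y for x, y in zip(a, b))


def attention_full(query, keys, values):
    """Reference: compute all scores, then one softmax over the whole row."""
    scores = [dot(query, k) for k in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(dim)]


def attention_tiled(query, keys, values, tile_size=2):
    """Same result, but processes keys/values tile by tile."""
    dim = len(values[0])
    running_max = float("-inf")
    running_sum = 0.0
    acc = [0.0] * dim  # unnormalized weighted sum of values so far
    for start in range(0, len(keys), tile_size):
        k_tile = keys[start:start + tile_size]
        v_tile = values[start:start + tile_size]
        scores = [dot(query, k) for k in k_tile]
        new_max = max(running_max, max(scores))
        # Rescale earlier partial results to the new maximum
        # (the factor is 0.0 on the first tile, when nothing is accumulated yet).
        scale = math.exp(running_max - new_max)
        running_sum *= scale
        acc = [a * scale for a in acc]
        for s, v in zip(scores, v_tile):
            w = math.exp(s - new_max)
            running_sum += w
            for j in range(dim):
                acc[j] += w * v[j]
        running_max = new_max
    return [a / running_sum for a in acc]


query = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0], [0.3, 0.3]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
print(attention_full(query, keys, values))
print(attention_tiled(query, keys, values, tile_size=2))
```

The two functions agree up to floating-point rounding. On a GPU, the benefit is that each tile can stay in fast on-chip memory while it is processed, instead of the full score matrix being written to and read back from HBM.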
FlashAttention has been widely adopted and has contributed to increasing the context window of LLMs from a few thousand tokens to hundreds of thousands or even millions of tokens.
As hardware has improved, so have the opportunities for optimizing LLM computations. FlashAttention-2, introduced in 2023, further optimized the use of GPU resources, achieving up to 70% of the theoretical maximum performance on Nvidia A100 GPUs. However, the same optimizations did not transfer to the newer H100 GPUs, where FlashAttention-2 reached only 35% of the maximum capacity.
FlashAttention-3
FlashAttention-3 takes advantage of new features in Nvidia Hopper GPUs to maximize performance. These features enable higher throughput on matrix multiplication operations, faster data transfer across different memory segments, and better efficiency on low-precision operations.
FlashAttention-3 introduces several innovations to improve the performance of attention computation on H100 GPUs.
FlashAttention-3 schedules operations in a way that maximizes the overlap between computation and the movement of data between different memory segments of the GPU. This reduces the time the GPU spends idle waiting for data to be transferred. It also interleaves the matrix multiplication and softmax operations to reduce the possibility of bottlenecks in computing attention values.
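The value of overlapping data movement with computation can be illustrated with a toy timing model. This is not how FlashAttention-3 is implemented. It only assumes that every tile takes a fixed time to load and a fixed time to compute, and that while one tile is being computed the next one can be loaded in the background. The function names and the time units are hypothetical.

```python
def serial_time(load: float, compute: float, tiles: int) -> float:
    """Each tile is loaded, then computed, with no overlap."""
    return tiles * (load + compute)


def overlapped_time(load: float, compute: float, tiles: int) -> float:
    """Load tile i+1 while computing tile i.

    The first load and the last compute cannot be hidden; every step in
    between costs whichever of the two activities is slower.
    """
    if tiles == 0:
        return 0.0
    return load + (tiles - 1) * max(load, compute) + compute


# Hypothetical workload: 8 tiles, 1 time unit to load, 2 to compute.
print(serial_time(1.0, 2.0, 8))      # 24.0
print(overlapped_time(1.0, 2.0, 8))  # 17.0
```

In this model the loads are almost entirely hidden behind computation, so the total time approaches the compute time alone. The same reasoning applies to interleaving the slower softmax work with matrix multiplications.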
FlashAttention-3 also uses a special arrangement of operations for faster and more accurate computations of attention in quantized models. Quantization is a popular technique that reduces the size of models by using low-bit numbers to store their weights. The tradeoff of quantization is the possible loss of accuracy. FlashAttention-3 addresses this problem by carefully arranging the computations to minimize the impact of quantization on accuracy.
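The accuracy tradeoff of quantization can be seen in a small generic sketch. It is an assumption-laden illustration, not the arrangement FlashAttention-3 uses: values are rounded to a signed integer-style grid with 127 levels, first with one scale for the whole list and then with a separate scale for each block of four values. The function names and the sample data are hypothetical.

```python
def quantize_dequantize(values, levels=127):
    """Round values to a grid whose step is set by the largest magnitude."""
    scale = max(abs(v) for v in values) / levels or 1.0  # avoid a zero step
    return [round(v / scale) * scale for v in values]


def quantize_per_block(values, block_size, levels=127):
    """Quantize each block with its own scale."""
    out = []
    for start in range(0, len(values), block_size):
        out.extend(quantize_dequantize(values[start:start + block_size], levels))
    return out


def mean_abs_error(original, approx):
    return sum(abs(a - b) for a, b in zip(original, approx)) / len(original)


# Mostly small values with a single large outlier (hypothetical data).
data = [0.01, -0.02, 0.03, 0.015, 50.0, -0.01, 0.02, 0.005]

tensor_err = mean_abs_error(data, quantize_dequantize(data))
block_err = mean_abs_error(data, quantize_per_block(data, block_size=4))
print(tensor_err, block_err)
```

With a single scale, the outlier stretches the grid so far that all the small values round to zero. Giving each block its own scale confines the damage to the block that contains the outlier, which is one general way that careful arrangement of low-precision computation can limit accuracy loss.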
According to the researchers, FlashAttention-3 achieves up to 75% usage of the H100 GPU’s maximum capabilities. This translates to a 1.5–2x speedup compared to previous versions of FlashAttention for both training and running LLMs.
The benefits of FlashAttention-3
The faster attention computation offered by FlashAttention-3 has several implications for LLM development and applications.
Training LLMs is a computationally expensive process that can take weeks or even months. Faster attention can significantly shorten training time, enabling researchers and developers to experiment with larger models and datasets.
FlashAttention-3 can also help extend the context window of LLMs by enabling them to process longer sequences more efficiently. This can unlock new applications for LLMs in areas such as long-form document understanding and many-shot in-context learning.
And by using a higher percentage of GPU capacity, FlashAttention-3 can reduce the number of accelerators required to run LLMs and slash the cost of running models in production.
The researchers have open-sourced FlashAttention-3 under a permissive license and plan to integrate it into popular deep learning libraries such as PyTorch and Hugging Face Transformers. This will make it easier for researchers and developers to take advantage of the performance benefits of FlashAttention-3.
“We have seen that designing algorithms that take advantage of the hardware they run on can bring significant efficiency gains and unlock new model capabilities such as long context,” the researchers wrote in a blog post published by Together AI. “We look forward to future work on optimization for LLM inference, as well as generalizing our techniques to other hardware architectures.”