FlashAttention: How Tri Dao Made Attention 4x Faster
Before FlashAttention, the standard implementation of transformer attention was doing something fundamentally wasteful. It was computing a massive N x N attention matrix, writing it to GPU memory, reading it back, and then using it. For long sequences, this intermediate matrix was enormous, and the time spent moving it between GPU memory tiers dominated the actual math.
Tri Dao's insight was that attention isn't a compute-bound problem. It's a memory-bandwidth problem. And by restructuring how the computation flows through the GPU's memory hierarchy, you can make it dramatically faster without approximating anything.
The Problem With Standard Attention
In a standard transformer, the attention mechanism computes Q * K^T to get an N x N matrix of attention scores (where N is the sequence length), applies softmax, and then multiplies by V. The naive implementation materializes that full N x N matrix in GPU high-bandwidth memory (HBM).
The issue is that HBM reads and writes are slow relative to the actual arithmetic. GPU SRAM (on-chip memory) is fast but small (around 20MB on an A100, compared to 40-80GB of HBM). The standard attention implementation is constantly shuttling data between SRAM and HBM, and most of the wall-clock time is spent on those memory transfers, not on the matrix multiplications themselves.
For a sequence length of 4096, the attention score matrix alone is 4096 x 4096 x 2 bytes (in FP16) = 32MB per head, regardless of the head dimension. That footprint grows quadratically with sequence length, so the GPU spends more and more time moving data instead of doing math.
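The quadratic growth is easy to see with a quick back-of-the-envelope calculation. This is just the arithmetic from the paragraph above, not a measurement; the sequence lengths are illustrative.

```python
# Size of the materialized N x N attention score matrix, per head,
# assuming FP16 (2 bytes per element). Independent of head dimension.

def score_matrix_bytes(seq_len: int, bytes_per_elem: int = 2) -> int:
    """Bytes needed to materialize the N x N score matrix for one head."""
    return seq_len * seq_len * bytes_per_elem

for n in (4096, 16384, 65536):
    mib = score_matrix_bytes(n) / (1 << 20)
    print(f"N={n:6d}: {mib:8.0f} MiB per head")
# N=  4096:       32 MiB per head
# N= 16384:      512 MiB per head
# N= 65536:     8192 MiB per head
```

Quadrupling the sequence length multiplies the score matrix by 16, which is why long contexts made the naive implementation untenable.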
How FlashAttention Works
FlashAttention restructures the attention computation using a technique called tiling. Instead of computing the full N x N attention matrix at once, it processes attention in small blocks that fit entirely in SRAM.
The algorithm:
- Divide Q, K, and V matrices into blocks that fit in SRAM.
- For each block of Q, load it into SRAM once.
- Iterate over blocks of K and V, computing partial attention scores in SRAM.
- Accumulate the softmax and output incrementally using an online softmax algorithm (tracking running max and sum statistics).
- Write only the final output back to HBM. The N x N attention matrix is never materialized.
The key mathematical trick is the online softmax. Normal softmax requires knowing the maximum value across the entire row before you can compute any outputs. The online version maintains running statistics that get corrected as new blocks are processed, producing numerically identical results without needing the full row in memory at once.
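The algorithm and the online-softmax correction can be sketched in a few lines of NumPy. This is illustrative only: the real FlashAttention runs these block updates inside a fused CUDA kernel with the blocks resident in SRAM, and the block size and shapes here are assumptions.

```python
import numpy as np

def flash_attention_forward(Q, K, V, block_size=64):
    """Compute softmax(Q K^T / sqrt(d)) V without materializing N x N."""
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros((N, d))
    m = np.full(N, -np.inf)   # running row-wise max
    l = np.zeros(N)           # running softmax denominator
    for j in range(0, N, block_size):
        Kj, Vj = K[j:j+block_size], V[j:j+block_size]
        S = (Q @ Kj.T) * scale                 # scores for this K/V block only
        m_new = np.maximum(m, S.max(axis=1))   # updated running max
        correction = np.exp(m - m_new)         # rescale previously accumulated stats
        P = np.exp(S - m_new[:, None])         # block-local softmax numerator
        l = l * correction + P.sum(axis=1)
        O = O * correction[:, None] + P @ Vj
        m = m_new
    return O / l[:, None]                      # normalize once at the end

def naive_attention(Q, K, V):
    S = (Q @ K.T) / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((256, 64)) for _ in range(3))
assert np.allclose(flash_attention_forward(Q, K, V), naive_attention(Q, K, V))
```

The `correction` factor is the whole trick: whenever a new block raises the running max, the previously accumulated numerator and denominator are rescaled so the final result is bit-for-bit the same softmax, just computed one block at a time.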
The result: FlashAttention is exact (not an approximation), uses O(N) memory instead of O(N^2), and runs 2-4x faster than standard attention on GPUs like the A100.
FlashAttention-2: Better Parallelism
FlashAttention-2 (July 2023) improved on the original by rethinking how work is distributed across the GPU's processing units.
The main changes:
- Reduced non-matmul FLOPs. The original FlashAttention spent a significant fraction of time on rescaling, softmax, and other non-tensor-core operations. V2 restructured the algorithm to minimize these.
- Better parallelism across the sequence length dimension. V1 parallelized over batch size and number of heads. V2 also parallelized over the sequence length, which matters for long sequences with small batch sizes (common in inference).
- Better work partitioning between warps within a thread block, reducing shared memory reads/writes.
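The sequence-length parallelism point is easiest to see with a quick occupancy sketch. The numbers below (108 SMs on an A100, a query block of 128) are illustrative assumptions, not the actual kernel launch parameters.

```python
# Rough count of independent thread blocks available to fill the GPU.

def thread_blocks_v1(batch: int, heads: int) -> int:
    """FA-1: one thread block per (batch, head) pair."""
    return batch * heads

def thread_blocks_v2(batch: int, heads: int, seq_len: int, q_block: int = 128) -> int:
    """FA-2: additionally split the query sequence into blocks."""
    seq_blocks = (seq_len + q_block - 1) // q_block
    return batch * heads * seq_blocks

# Long-context inference: batch 1, 16 heads, 32K tokens, ~108 SMs on an A100.
print(thread_blocks_v1(1, 16))          # 16 blocks -> most SMs sit idle
print(thread_blocks_v2(1, 16, 32768))   # 4096 blocks -> plenty to saturate the GPU
```

With batch size 1 and 16 heads, v1 can only keep 16 of the A100's 108 SMs busy; splitting over the sequence dimension gives the scheduler thousands of independent blocks instead.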
These changes brought FlashAttention-2 to around 230 TFLOP/s on an A100, roughly 2x faster than FlashAttention-1 and about 73% of the A100's theoretical peak FP16 throughput (312 TFLOP/s).
FlashAttention-3: Exploiting Hopper
FlashAttention-3 (July 2024) was designed specifically for NVIDIA's Hopper architecture (H100 GPUs), which introduced new hardware capabilities that the previous versions couldn't use.
Three key techniques:
Warp specialization with asynchronous execution. Hopper has a new asynchronous programming model where different warp groups can overlap computation and memory transfers. FlashAttention-3 designates some warps as "producers" (loading data) and others as "consumers" (doing math), running them concurrently.
Interleaved block-wise matmul and softmax. Instead of waiting for all matmuls to finish before computing softmax (or vice versa), FA-3 interleaves these operations to keep the tensor cores busy while softmax runs on the CUDA cores.
FP8 support with incoherent processing. FA-3 adds block-wise FP8 quantization for attention, achieving ~1.2 PFLOP/s in FP8 while keeping numerical error 2.6x lower than naive FP8 attention through a technique called incoherent processing (applying random orthogonal transformations to spread quantization error).
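The core identity behind incoherent processing is that multiplying Q and K by the same orthogonal matrix leaves Q K^T unchanged, while mixing any outlier value across all dimensions before quantization. A NumPy sketch of that identity (the real implementation uses fast structured transforms such as a randomized Hadamard transform; the dense QR-based orthogonal matrix here is an illustrative assumption):

```python
import numpy as np

rng = np.random.default_rng(0)
d = 64
Q = rng.standard_normal((128, d))
K = rng.standard_normal((128, d))

# Random orthogonal matrix M (M @ M.T == I) via QR decomposition.
M, _ = np.linalg.qr(rng.standard_normal((d, d)))

Qr, Kr = Q @ M, K @ M
# Attention scores are exactly preserved...
assert np.allclose(Q @ K.T, Qr @ Kr.T)
# ...while any single large entry in Q or K gets spread across all d
# dimensions, shrinking the dynamic range FP8 quantization must cover.
```

Because the rotation cancels inside Q K^T, the technique costs only the transform itself and changes nothing about the attention output in exact arithmetic; it only reduces quantization error.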
The result is 740 TFLOP/s in FP16 (75% hardware utilization) on an H100.
Flash-Decoding: Fixing the Decode Phase
There's a specific problem during autoregressive decoding (generating one token at a time) that the main FlashAttention papers didn't fully address. During decoding, the query length is 1 (just the new token), but the key/value length can be very long (the entire context). Standard FlashAttention parallelizes over batch size, heads, and query length, but when query length is 1, there's almost nothing to parallelize.
Flash-Decoding (October 2023) added a new parallelization dimension: splitting across the key/value sequence. The KV sequence is divided into chunks, each chunk computes partial attention with the single query in parallel, and the results are combined with a log-sum-exp correction.
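The split-and-combine step can be sketched in NumPy. Each chunk computes a chunk-local attention output plus its log-sum-exp; the final answer is a softmax-weighted combination of the chunk outputs. The chunk count and shapes are illustrative assumptions, and the chunks here run sequentially where Flash-Decoding runs them in parallel.

```python
import numpy as np

def decode_attention_split(q, K, V, num_chunks=4):
    """Attention for a single query vector q against a long K/V cache."""
    scale = 1.0 / np.sqrt(q.shape[0])
    outs, lses = [], []
    # Each chunk is independent -- Flash-Decoding computes them in parallel.
    for Kc, Vc in zip(np.array_split(K, num_chunks), np.array_split(V, num_chunks)):
        s = (Kc @ q) * scale
        m = s.max()
        p = np.exp(s - m)
        outs.append((p @ Vc) / p.sum())     # chunk-local attention output
        lses.append(m + np.log(p.sum()))    # chunk log-sum-exp
    lses = np.array(lses)
    w = np.exp(lses - lses.max())
    w /= w.sum()                            # softmax over chunk log-sum-exps
    return w @ np.array(outs)               # weighted combination of chunks

rng = np.random.default_rng(0)
q = rng.standard_normal(64)
K = rng.standard_normal((1024, 64))
V = rng.standard_normal((1024, 64))

# Reference: ordinary single-query attention over the full cache.
s = (K @ q) / np.sqrt(64)
p = np.exp(s - s.max())
reference = (p / p.sum()) @ V
assert np.allclose(decode_attention_split(q, K, V), reference)
```

The log-sum-exp weights are exactly each chunk's share of the global softmax denominator, so the combined result matches full-cache attention while the expensive per-chunk work parallelizes across the KV sequence.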
This is especially important for long-context inference. Without Flash-Decoding, a single decode step with 100K context might use less than 1% of the GPU's compute capacity. With it, you can actually saturate the hardware, achieving up to 8x faster decoding for long sequences.
Why This Matters (And Why ASICs Go Further)
FlashAttention and its descendants are now built into essentially every production LLM serving system: vLLM, TensorRT-LLM, SGLang, and every major inference provider all use some variant of these kernels. The impact is real: prefill got 2-4x faster, long-context decoding became usable, and 128K+ context windows went from impractical to standard.
But here's the thing worth noting. FlashAttention exists because GPUs have a fundamental architectural mismatch for inference workloads. The entire family of techniques is about working around the bottleneck of moving data between HBM and SRAM on a chip that was designed for graphics rendering and general-purpose parallel compute.
General Compute is the only neocloud built entirely on inference-optimized ASICs instead of NVIDIA GPUs. The memory bandwidth constraints that FlashAttention was built to solve are addressed at the hardware level on these chips. There's no need to tile around a slow memory bus because the memory architecture is purpose-built for the access patterns that transformer inference actually needs. Combined with our own optimizations like disaggregated inference, this is why we're fundamentally faster than GPU-based providers, even ones running FlashAttention.
If you want to see the difference that purpose-built inference hardware makes, sign up at generalcompute.com and get $5 in free credit to try it out.
Papers and References
- FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (Dao et al., 2022)
- FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning (Dao, 2023)
- FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision (Shah et al., 2024)
- Flash-Decoding for long-context inference (Stanford CRFM, 2023)