Flash Attention: Why Modern LLMs Run Faster With It
Flash Attention rewrites the attention computation to avoid moving a giant intermediate matrix in and out of GPU memory. Here is how the tiling and kernel fusion work, how v1, v2, and v3 evolved, and how to turn it on in PyTorch.
- Author: General Compute
- Published: 2026-06-03
- Tags: flash-attention, attention, gpu, inference-optimization
If you have profiled a transformer and watched the GPU sit at low utilization while memory bandwidth pegs at 100%, you have met the problem Flash Attention solves. Attention is the core operation in every modern large language model, and the naive way to compute it spends most of its time shuffling a large intermediate matrix between fast on-chip memory and slow high-bandwidth memory. Flash Attention restructures the computation so that matrix never gets written out in full, and the result is a kernel that runs several times faster while using far less memory.
This post explains what Flash Attention actually does, why the standard implementation was so wasteful, how the algorithm avoids the waste through tiling and kernel fusion, and how the v1, v2, and v3 versions differ. It also covers how to enable it in PyTorch, which is usually a one-line change.
## Why Standard Attention Is Slow
Self-attention takes three matrices: queries Q, keys K, and values V, each of shape (sequence length, head dimension). The math is straightforward:
```
S = Q @ K^T # scores, shape (N, N)
P = softmax(S) # attention weights, shape (N, N)
O = P @ V # output, shape (N, d)
```
The catch is the N-by-N matrices S and P, where N is the sequence length. For a sequence of 8,192 tokens, each of those matrices has roughly 67 million entries per attention head. The standard implementation computes S in full, writes it to GPU memory, reads it back to apply softmax, writes P back out, then reads it again for the final multiply by V.
That back-and-forth is the bottleneck. A modern GPU like an H100 can do hundreds of teraflops of matrix math, but its high-bandwidth memory (HBM) moves data at a few terabytes per second, which sounds fast until you realize how much data attention moves. The on-chip SRAM that sits next to the compute units is far faster but tiny, on the order of tens of megabytes for the whole chip. The giant S and P matrices do not fit in SRAM, so they live in HBM, and every read and write of them costs bandwidth.
Computed this way, attention is what people call memory-bound: the arithmetic is cheap relative to the data movement. The GPU spends its time waiting on memory rather than doing math, which is why utilization looks bad. The fix is not to do less math. It is to stop moving the intermediate matrices around.
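To make the three-step computation concrete, here is a minimal NumPy sketch of it. It is a reference for illustration only, not how any framework implements attention: the function name `naive_attention` and the small sizes are assumptions, and it includes the usual 1/sqrt(d) scaling and row-max subtraction that the formula above leaves out for brevity.

```python
import numpy as np

def naive_attention(Q, K, V):
    """Reference attention that materializes the full (N, N) matrices."""
    d = Q.shape[-1]
    S = (Q @ K.T) / np.sqrt(d)             # scores, shape (N, N)
    S = S - S.max(axis=-1, keepdims=True)  # subtract the row max for stability
    P = np.exp(S)
    P = P / P.sum(axis=-1, keepdims=True)  # attention weights, shape (N, N)
    return P @ V, P                        # output (N, d) and the weights

rng = np.random.default_rng(0)
N, d = 256, 64                             # small sizes so this runs anywhere
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
O, P = naive_attention(Q, K, V)

print(O.shape, P.shape)                    # (256, 64) (256, 256)
print(f"entries per N x N matrix at N=8192: {8192 * 8192:,}")
```

Both `S` and `P` exist in full here, which is exactly the pattern that becomes expensive once N reaches the thousands.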
## The Core Idea: Never Materialize the Full Matrix
Flash Attention, introduced by Tri Dao and collaborators in 2022, is built on one observation: you do not need the entire score matrix in memory at once to compute the final output. You can process the sequence in blocks, keep each block's intermediate results in fast SRAM, and accumulate the output incrementally. The full N-by-N matrix never exists in HBM.
The challenge is softmax. A normal softmax needs to see every score in a row before it can normalize, because the denominator is a sum over all of them, and the numerically stable version also needs the maximum score in the row. If you are only looking at one block of keys at a time, you do not have the full row yet. This is where the online softmax trick comes in.
The idea, which predates Flash Attention but is central to it, is that you can compute softmax in a streaming fashion. As you process each new block of keys and values, you keep a running maximum and a running sum, and you rescale the partial output you have accumulated so far to account for the new information. When a new block produces a larger maximum than you have seen, you scale down the existing running sum and the existing output by the appropriate factor, then add the new block's contribution. By the time you have processed every block, the running result equals exactly what the full softmax would have produced.
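The streaming update is easier to see on a single row of scores. The sketch below is a pure-Python illustration under simple assumptions: the helper names `online_softmax` and `two_pass_softmax` and the block size are made up for this example. It emits the normalized weights in a final pass only so the result can be compared; Flash Attention instead folds each block into the running output and never needs that pass.

```python
import math

def online_softmax(scores, block_size=4):
    """Softmax over `scores`, reading the row one block at a time."""
    running_max = float("-inf")
    running_sum = 0.0
    for start in range(0, len(scores), block_size):
        block = scores[start:start + block_size]
        new_max = max(running_max, max(block))
        # Rescale the old sum to the new maximum, then add this block.
        running_sum *= math.exp(running_max - new_max)
        running_sum += sum(math.exp(s - new_max) for s in block)
        running_max = new_max
    # Final pass only to emit the weights for comparison.
    return [math.exp(s - running_max) / running_sum for s in scores]

def two_pass_softmax(scores):
    """Ordinary stable softmax that sees the whole row at once."""
    row_max = max(scores)
    exps = [math.exp(s - row_max) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

row = [0.5, 2.0, -1.0, 3.0, 7.5, 0.0, 1.5, -2.0, 4.0]
streamed = online_softmax(row)
reference = two_pass_softmax(row)
print(all(math.isclose(a, b, rel_tol=1e-9) for a, b in zip(streamed, reference)))
```

The largest score (7.5) arrives in the second block, so the first block's partial sum gets rescaled when it shows up, which is the case the paragraph above describes.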
So the algorithm looks roughly like this for each query block:
1. Load the query block into SRAM.
2. Loop over key and value blocks. For each one, load it into SRAM, compute the block of scores, update the running maximum and running sum, and update the running output.
3. After the last key block, write the final output for this query block back to HBM.
Apart from a small vector of per-row softmax statistics kept for the backward pass, the only thing written to HBM is the output O, which is the same size as the input. The score matrix S and the attention weights P are computed in small blocks inside SRAM and discarded as soon as their contribution is folded into the running totals. This is kernel fusion: the matrix multiply, the softmax, and the second matrix multiply are fused into a single GPU kernel rather than three separate ones, so there is no round trip to HBM between the steps.
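The three steps above can be sketched in NumPy. This is a single-head, unmasked illustration of the blocking and rescaling logic, not the real kernel: NumPy has no notion of SRAM, and the names `tiled_attention` and `reference_attention` and the block sizes are assumptions chosen for readability.

```python
import numpy as np

def tiled_attention(Q, K, V, q_block=32, kv_block=32):
    """Blockwise attention with an online softmax (single head, no mask)."""
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.empty((N, V.shape[1]))
    for qs in range(0, N, q_block):
        q = Q[qs:qs + q_block]                     # step 1: load a query block
        m = np.full(q.shape[0], -np.inf)           # running row maximum
        l = np.zeros(q.shape[0])                   # running row sum
        acc = np.zeros((q.shape[0], V.shape[1]))   # running unnormalized output
        for ks in range(0, K.shape[0], kv_block):  # step 2: loop over K/V blocks
            k, v = K[ks:ks + kv_block], V[ks:ks + kv_block]
            s = (q @ k.T) * scale                  # one small block of scores
            m_new = np.maximum(m, s.max(axis=1))
            p = np.exp(s - m_new[:, None])
            correction = np.exp(m - m_new)         # rescales the earlier totals
            l = l * correction + p.sum(axis=1)
            acc = acc * correction[:, None] + p @ v
            m = m_new
        O[qs:qs + q_block] = acc / l[:, None]      # step 3: write the output block
    return O

def reference_attention(Q, K, V):
    """Full-matrix attention used only to check the tiled result."""
    S = (Q @ K.T) / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((100, 16)) for _ in range(3))
print(np.allclose(tiled_attention(Q, K, V), reference_attention(Q, K, V)))
```

The largest score array that ever exists inside `tiled_attention` is `q_block` by `kv_block`, regardless of N. The sequence length of 100 is deliberately not a multiple of the block size, so the last block in each loop is a partial one.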
## Why It Is Faster Even Though It Does More Math
There is a wrinkle worth understanding. Flash Attention actually recomputes some things. During the forward pass it does not store the full attention matrix, so during the backward pass, when gradients need those values, it recomputes the relevant blocks on the fly instead of reading them from memory. That is extra arithmetic.
It is still faster overall, and the reason is the memory-bound nature of the problem. The recomputed math runs on compute units that were otherwise idle waiting for memory. Trading cheap, abundant arithmetic for expensive, scarce memory bandwidth is a good trade on this hardware. The number of HBM accesses still grows quadratically with sequence length, but it shrinks by a large factor (on the order of the SRAM size divided by the square of the head dimension), and since HBM access was the limiting resource, the wall-clock time drops with it.
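A rough sense of the size of that reduction comes from the IO analysis in the original Flash Attention paper, which counts Θ(N·d + N²) HBM accesses for standard attention and Θ(N²·d²/M) for Flash Attention, where M is the SRAM size in elements. The sketch below plugs in numbers while ignoring constant factors, so treat the output as an order-of-magnitude estimate only; the head dimension of 64 and the roughly 100 KB of SRAM holding 2-byte values are assumptions.

```python
def hbm_accesses(n, d, sram_elems):
    """Order-of-magnitude element accesses; constant factors are ignored."""
    standard = n * d + n * n            # Theta(N*d + N^2)
    flash = n * n * d * d / sram_elems  # Theta(N^2 * d^2 / M)
    return standard, flash

# Assumed values: head_dim 64, about 100 KB of SRAM holding 2-byte elements.
d, sram_elems = 64, 100 * 1024 // 2
for n in (2048, 8192, 32768):
    standard, flash = hbm_accesses(n, d, sram_elems)
    print(f"N={n:>6}: standard ~{standard:.2e}, flash ~{flash:.2e}, "
          f"ratio ~{standard / flash:.1f}x")
```

Both columns grow with N squared, so the ratio stays roughly constant as the sequence gets longer. The saving comes from the d²/M factor, not from a change in the exponent.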
The memory savings are the other big win. Standard attention uses memory proportional to N squared because it stores the full score matrix. Flash Attention uses memory proportional to N, because it only ever holds blocks plus the running statistics. This is what makes long-context models practical. A 100,000-token context would need an attention matrix with ten billion entries per head under the naive approach, which simply does not fit. With Flash Attention the memory footprint stays manageable, and long context becomes a question of compute time rather than a hard memory wall.
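The arithmetic behind those figures is short enough to check directly. In this sketch the helper names are hypothetical, 2 bytes per element assumes fp16 or bf16 scores, and one 4-byte statistic per row is an assumption about what the fused kernel keeps.

```python
def score_matrix_bytes(seq_len, bytes_per_elem=2):
    """Bytes for one (N, N) matrix for one head; 2 bytes assumes fp16/bf16."""
    return seq_len * seq_len * bytes_per_elem

def row_stats_bytes(seq_len, bytes_per_elem=4):
    """Bytes for per-row softmax statistics (one fp32 value per row, assumed)."""
    return seq_len * bytes_per_elem

for n in (8_192, 100_000):
    print(f"N={n:>7,}: score matrix {score_matrix_bytes(n) / 1e9:.2f} GB, "
          f"row statistics {row_stats_bytes(n) / 1e3:.1f} KB")
```

At 100,000 tokens the score matrix alone is 20 GB per head in half precision, before counting the second N-by-N matrix, the other heads, or the batch dimension.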
## How v1, v2, and v3 Evolved
The original Flash Attention (v1, 2022) established the tiling and online-softmax approach described above. It delivered large speedups over standard attention and immediately found its way into training and inference stacks. Its main limitation was that it did not fully saturate the GPU. The way it partitioned work across the GPU's parallel units left some compute on the table, and the loop structure was not ideal.
Flash Attention v2 (2023) reworked the parallelization. The key changes were reducing the number of non-matmul operations (which run on slower units than the dedicated matrix-multiply hardware), parallelizing across the sequence-length dimension in addition to the batch and head dimensions, and improving how work is split between thread blocks and warps inside the GPU. The result was roughly a 2x speedup over v1, pushing attention to a much higher fraction of the GPU's theoretical peak throughput. For most people running on A100 or H100-class hardware today, v2 is the version doing the work under the hood.
Flash Attention v3 (2024) targets the Hopper generation of GPUs (H100 and similar) specifically. Hopper added hardware features that v2 did not exploit, including asynchronous execution, where the matrix-multiply units and the memory-movement units can run at the same time without waiting on each other, and native support for lower-precision formats like FP8. Flash Attention v3 overlaps computation and data movement more aggressively using these asynchronous instructions, and it adds an FP8 path for inference where the reduced precision buys additional throughput. On H100 hardware the authors report that v3 runs roughly 1.5x to 2x faster than v2 in FP16 and reaches a substantially higher fraction of peak FLOPs, and that the FP8 mode raises peak throughput further for workloads that can tolerate the lower precision.
The throughline across all three versions is the same: keep the data in fast memory, fuse the operations, and adapt the work partitioning to whatever the current hardware does best.
## Where Flash Attention Helps Most
Not every workload benefits equally. The gains scale with sequence length, because the wasted memory traffic in standard attention grows quadratically with N. For short prompts the difference is modest. For long-context work, document processing, retrieval-augmented generation with large contexts, or any agent that accumulates a long history, the difference is large, both in speed and in whether the workload fits in memory at all.
During inference, attention shows up in two phases. The prefill phase processes the entire prompt at once and is compute-heavy, and Flash Attention helps it directly. The decode phase generates one token at a time, where the query is a single token attending to a growing cache of keys and values; here the operation is even more memory-bound, and variants of the Flash Attention idea adapted for decoding, sometimes called Flash-Decoding, parallelize over the key-value cache to keep the GPU busy when there is only one query. If you care about the decode side, it is worth reading our post on [KV cache compression](/blog/kv-cache-compression-mla-and-beyond) to see how the cache interacts with attention.
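The decode-side idea can be sketched as follows: split the cache into chunks, compute attention for the single query over each chunk independently, and merge the partial results using each chunk's log-sum-exp. This NumPy version is a sketch of that merge step under simple assumptions, not the Flash-Decoding implementation; the function names and `num_splits` are hypothetical, and the chunks run sequentially here where a GPU would run them in parallel.

```python
import numpy as np

def partial_attention(q, K, V):
    """Attention of one query over one cache chunk, plus its log-sum-exp."""
    s = (K @ q) / np.sqrt(q.shape[0])    # scores against this chunk
    m = s.max()
    p = np.exp(s - m)
    return (p @ V) / p.sum(), m + np.log(p.sum())

def split_kv_decode(q, K, V, num_splits=4):
    """Merge per-chunk results; each chunk is independent of the others."""
    chunks = zip(np.array_split(K, num_splits), np.array_split(V, num_splits))
    outs, lses = zip(*(partial_attention(q, k, v) for k, v in chunks))
    lses = np.array(lses)
    weights = np.exp(lses - lses.max())  # each chunk's share of the softmax mass
    weights /= weights.sum()
    return weights @ np.stack(outs)

rng = np.random.default_rng(0)
q = rng.standard_normal(64)              # the single decode-step query
K = rng.standard_normal((1000, 64))      # cached keys
V = rng.standard_normal((1000, 64))      # cached values

full, _ = partial_attention(q, K, V)     # one chunk covering the whole cache
print(np.allclose(split_kv_decode(q, K, V), full))
```

Because the merge only needs one output vector and one scalar per chunk, the extra reduction is cheap compared with reading the cache itself.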
## Turning It On in PyTorch
The good news is that you usually do not implement any of this yourself. Modern frameworks expose Flash Attention through a single function, and in many cases it is already the default.
PyTorch 2.x ships `scaled_dot_product_attention`, which dispatches to a Flash Attention kernel automatically when the inputs and hardware support it:
```python
import torch
import torch.nn.functional as F
# q, k, v have shape (batch, heads, seq_len, head_dim)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```
PyTorch picks the best available backend for your tensors and GPU. If you want to be explicit, you can scope which backends are allowed:
```python
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# Allow only the Flash Attention backend inside this block
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```
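One way to convince yourself that the fused call computes the same thing as the three-step formula is to compare the two on small inputs. The sketch below uses `float32` tensors on CPU so it runs without a GPU, which means it exercises whichever backend PyTorch selects there rather than the CUDA Flash Attention kernel. The tensor sizes and the tolerance are assumptions; on GPU with half-precision inputs a looser tolerance would be needed.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# (batch, heads, seq_len, head_dim); float32 on CPU so no GPU is required
q, k, v = (torch.randn(1, 2, 128, 64) for _ in range(3))

# Three-step reference with an explicit causal mask
scores = (q @ k.transpose(-2, -1)) / math.sqrt(q.size(-1))
causal = torch.ones(128, 128, dtype=torch.bool).tril()
scores = scores.masked_fill(~causal, float("-inf"))
reference = torch.softmax(scores, dim=-1) @ v

fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(fused, reference, atol=1e-5))
```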
For the upstream implementation with the newest kernels, the `flash-attn` package gives you the functions directly:
```bash
pip install flash-attn --no-build-isolation
```
```python
from flash_attn import flash_attn_func
# q, k, v have shape (batch, seq_len, heads, head_dim); note that seq_len and
# heads are swapped relative to the scaled_dot_product_attention layout
out = flash_attn_func(q, k, v, causal=True)
```
When loading a model through Hugging Face Transformers, you can request the implementation with a single argument:
```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "your-model",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)
```
A few practical notes. Flash Attention kernels are written for half precision, so you generally need to run in `bfloat16` or `float16`. With `scaled_dot_product_attention`, `float32` tensors fall back to a slower backend, or raise an error if you have restricted dispatch to the Flash Attention backend; the `flash-attn` package rejects them outright. Head dimensions and sequence shapes have some constraints depending on the version, though the common cases are all supported. And if you install `flash-attn` from source, expect a long compile, which is why the prebuilt wheels are worth using when they match your CUDA version.
## The Takeaway
Flash Attention is one of those rare optimizations that is both a large speedup and free, in the sense that it computes the same mathematical result as standard attention, up to floating-point rounding. There is no approximation and no accuracy trade-off in the standard half-precision kernels; the optional FP8 mode in v3 is the one place where precision is deliberately reduced. It is simply a smarter way to schedule the computation around the realities of GPU memory hierarchy. That is why it went from a research paper to the default attention kernel in essentially every serious inference stack within a couple of years.
For anyone building on top of language models, the main thing to know is that it is there, it is usually on by default in current frameworks, and it is the reason long-context models are practical at all. At General Compute we run optimized attention kernels across our inference stack so you get the throughput and latency benefits without thinking about which version is dispatched. If you want to see what that looks like on real models, the fastest way is to point your existing OpenAI-compatible client at our [API](https://generalcompute.com) and measure the token rate yourself.