Multi-Query and Grouped-Query Attention: Shrinking the KV Cache
Standard multi-head attention (MHA) gives each attention head its own keys and values. For a model with 32 heads, that means 32 separate key tensors and 32 separate value tensors per layer stored in the KV cache (the keys and values already computed for previous tokens, kept in memory so they don't have to be recomputed at every step). During generation, every one of those tensors has to be read from memory for every token produced.
This is a lot of memory traffic for information that's often highly redundant across heads. Two papers showed that you can share key-value heads across multiple query heads with minimal quality loss, dramatically reducing the KV cache size and speeding up inference.
The Memory Bandwidth Problem in Decoding
During autoregressive decoding (generating one token at a time), the model reads the entire KV cache for all previous tokens at each step. Take a model with Llama 2 70B's dimensions (80 layers, 64 attention heads of dimension 128) and suppose it used full multi-head attention: at a 4096-token context in FP16, the KV cache for a single sequence would be about 10 GiB (80 x 64 x 2 x 4096 x 128 x 2 bytes). Every token generation step requires reading all of that from GPU high-bandwidth memory (HBM).
With standard multi-head attention, the KV cache size scales as: num_layers x num_heads x 2 (one for K, one for V) x sequence_length x head_dimension x bytes_per_element. The num_heads term is the target of MQA and GQA.
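Plugging numbers into this formula makes the scale concrete. The sketch below assumes Llama 2 70B's dimensions (80 layers, 64 query heads, head dimension 128), one 4096-token sequence, and FP16 storage. `kv_cache_bytes` is a hypothetical helper, not a library function. The released Llama 2 70B actually uses 8 KV heads, which corresponds to the GQA-8 line.

```python
GIB = 1024 ** 3
MIB = 1024 ** 2


def kv_cache_bytes(num_layers, num_kv_heads, seq_len, head_dim, bytes_per_element=2):
    """Bytes needed to cache K and V for one sequence.

    The factor of 2 covers storing both keys and values;
    bytes_per_element=2 assumes FP16/BF16 storage.
    """
    return num_layers * num_kv_heads * 2 * seq_len * head_dim * bytes_per_element


# Assumed Llama 2 70B-like shape: 80 layers, head_dim 128, 4096 tokens.
mha = kv_cache_bytes(num_layers=80, num_kv_heads=64, seq_len=4096, head_dim=128)
gqa8 = kv_cache_bytes(num_layers=80, num_kv_heads=8, seq_len=4096, head_dim=128)
mqa = kv_cache_bytes(num_layers=80, num_kv_heads=1, seq_len=4096, head_dim=128)

print(f"MHA-64: {mha / GIB:.2f} GiB")   # MHA-64: 10.00 GiB
print(f"GQA-8:  {gqa8 / GIB:.2f} GiB")  # GQA-8:  1.25 GiB
print(f"MQA:    {mqa / MIB:.0f} MiB")   # MQA:    160 MiB
```

Only `num_kv_heads` changes between the three lines, which is exactly the term MQA and GQA go after.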
If you could reduce the number of KV heads without hurting model quality, you'd directly reduce memory bandwidth requirements during decoding, which is the primary bottleneck for inference speed.
Multi-Query Attention (MQA)
Multi-Query Attention was proposed by Noam Shazeer (one of the original Transformer paper co-authors) back in 2019. The idea is radical in its simplicity: instead of giving each attention head its own keys and values, use a single key head and a single value head shared across all query heads.
Each query head still computes its own unique attention pattern (so the model can still attend to different things from different perspectives), but they all attend over the same set of keys and values.
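A minimal NumPy sketch of one layer's MQA attention step, assuming no batch dimension, no causal mask, and q, k and v already produced by their projections. `multi_query_attention` is a hypothetical name for illustration. The thing to notice is the shapes: q carries a head axis, while k and v do not, so per layer the cache holds one key tensor and one value tensor regardless of how many query heads there are.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the max for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def multi_query_attention(q, k, v):
    """q: (num_heads, seq_q, d); k, v: (seq_k, d), one head shared by all queries."""
    d = q.shape[-1]
    # Each query head gets its own score matrix against the SAME keys:
    # (num_heads, seq_q, d) @ (d, seq_k) broadcasts to (num_heads, seq_q, seq_k).
    scores = q @ k.T / np.sqrt(d)
    weights = softmax(scores, axis=-1)
    # ...and mixes the SAME values: result is (num_heads, seq_q, d).
    return weights @ v


rng = np.random.default_rng(0)
num_heads, seq_len, head_dim = 8, 5, 16
q = rng.standard_normal((num_heads, seq_len, head_dim))
k = rng.standard_normal((seq_len, head_dim))  # single shared key head
v = rng.standard_normal((seq_len, head_dim))  # single shared value head
out = multi_query_attention(q, k, v)
print(out.shape)  # (8, 5, 16)
```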
The impact on the KV cache is dramatic. For a model with 64 heads, MQA reduces the KV cache size by 64x, so a cache measured in gigabytes shrinks to one measured in megabytes.
In practice, the quality impact is small but measurable. MQA was adopted by several major models including PaLM (Google's large language model) and Falcon. The tradeoff was considered worthwhile because the inference speedup is enormous, especially for long sequences where the KV cache dominates memory usage.
Paper: "Fast Transformer Decoding: One Write-Head is All You Need" (Shazeer, 2019)
Grouped-Query Attention (GQA)
Grouped-Query Attention (Ainslie et al., Google, May 2023) is the middle ground between standard multi-head attention and multi-query attention. Instead of one KV head shared by all query heads (MQA) or one KV head per query head (MHA), GQA splits the query heads into groups and gives each group its own shared KV head, for an intermediate number of KV heads.
For example, a model with 32 query heads might use 8 KV head groups, so each group of 4 query heads shares one set of keys and values. This gives you a 4x reduction in KV cache (compared to MHA) while keeping quality closer to the full multi-head version.
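A NumPy sketch of GQA, assuming no batch dimension, no causal mask, and pre-projected q, k and v. It assumes consecutive query heads form a group, so query head h reads KV head h // group_size; that is a common convention but an assumption here, and `grouped_query_attention` is a hypothetical name. Setting `num_kv_heads` equal to the number of query heads recovers MHA, and setting it to 1 recovers MQA.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the max for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def grouped_query_attention(q, k, v):
    """q: (num_q_heads, seq_q, d); k, v: (num_kv_heads, seq_k, d).

    num_q_heads must be a multiple of num_kv_heads.
    """
    num_q_heads, _, d = q.shape
    num_kv_heads = k.shape[0]
    assert num_q_heads % num_kv_heads == 0
    group_size = num_q_heads // num_kv_heads
    # Repeat each KV head group_size times along the head axis, so query
    # head h lines up with KV head h // group_size. Only num_kv_heads
    # tensors live in the cache; the copies exist just for this computation.
    k_rep = np.repeat(k, group_size, axis=0)  # (num_q_heads, seq_k, d)
    v_rep = np.repeat(v, group_size, axis=0)
    scores = q @ k_rep.transpose(0, 2, 1) / np.sqrt(d)
    return softmax(scores, axis=-1) @ v_rep


rng = np.random.default_rng(0)
q = rng.standard_normal((32, 6, 16))  # 32 query heads
k = rng.standard_normal((8, 6, 16))   # 8 KV heads -> groups of 4
v = rng.standard_normal((8, 6, 16))
out = grouped_query_attention(q, k, v)
print(out.shape)  # (32, 6, 16)
```

The `np.repeat` copy is there for readability; optimized attention kernels generally index the shared KV head directly rather than materializing copies.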
The paper also showed something practically useful: you can take an existing model that was trained with multi-head attention and "uptrain" it (continue training for a short period) to use grouped-query attention, using only about 5% of the original pre-training compute. You don't need to train from scratch.
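The paper builds each shared K and V head of the converted checkpoint by mean-pooling the projection weights of the original heads in its group, then continues pre-training. The sketch below covers only that conversion step, assuming per-head weights laid out as (num_heads, d_model, head_dim) with consecutive heads forming a group. Real checkpoints often store a fused (d_model, num_heads * head_dim) matrix that would have to be reshaped first. `mean_pool_kv_heads` and the layer sizes are hypothetical.

```python
import numpy as np


def mean_pool_kv_heads(w_kv, num_kv_heads):
    """Convert per-head K (or V) projection weights from MHA to GQA.

    w_kv: (num_heads, d_model, head_dim), one projection per original head.
    Returns (num_kv_heads, d_model, head_dim), each the mean of its group.
    """
    num_heads, d_model, head_dim = w_kv.shape
    assert num_heads % num_kv_heads == 0
    group_size = num_heads // num_kv_heads
    # Row-major reshape puts consecutive heads in the same group.
    grouped = w_kv.reshape(num_kv_heads, group_size, d_model, head_dim)
    return grouped.mean(axis=1)


rng = np.random.default_rng(0)
w_k_mha = rng.standard_normal((32, 512, 64))  # hypothetical MHA key projection
w_k_gqa = mean_pool_kv_heads(w_k_mha, num_kv_heads=8)
print(w_k_gqa.shape)  # (8, 512, 64)
```

The query projections are left untouched; only K and V shrink, which is what reduces the cache.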
Results: GQA-8 (8 KV groups) achieves quality close to full MHA while running at speeds close to MQA. The authors demonstrated this by uptraining T5-XXL (about 11B parameters) checkpoints.
Adoption: GQA was quickly adopted by the industry. Llama 2 70B, Llama 3 (all sizes), Mistral, and most modern open-source models use GQA. It's become the default attention configuration for new models because the quality-speed tradeoff is so favorable.
Paper: "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints" (Ainslie et al., 2023)
How MQA and GQA Affect Inference
The practical impact on inference systems:
Smaller KV cache = more concurrent requests. With GQA-8 instead of MHA-32, the KV cache is 4x smaller per request. On a GPU with fixed memory, the space left for the cache after loading the weights holds up to 4x more concurrent sequences (or 4x longer contexts) before running out. This compounds with PagedAttention (covered in our vLLM post): each fixed-size block of tokens takes fewer bytes, so more blocks fit in the same memory pool.
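A back-of-the-envelope sketch of that capacity math. All numbers are hypothetical: a 32-layer model with 32 query heads of dimension 128, 4096-token sequences, an FP16 cache, and 40 GiB of GPU memory left for the KV cache after the weights are loaded. It ignores fragmentation and paging granularity.

```python
GIB = 1024 ** 3


def kv_bytes_per_sequence(num_layers, num_kv_heads, seq_len, head_dim, bytes_per_element=2):
    # K and V (factor of 2) for every layer, KV head, and token.
    return num_layers * num_kv_heads * 2 * seq_len * head_dim * bytes_per_element


def max_concurrent_sequences(kv_budget_bytes, **dims):
    # Number of whole sequences whose cache fits in the budget.
    return kv_budget_bytes // kv_bytes_per_sequence(**dims)


budget = 40 * GIB  # hypothetical memory left for the cache
dims = dict(num_layers=32, seq_len=4096, head_dim=128)
mha_seqs = max_concurrent_sequences(budget, num_kv_heads=32, **dims)
gqa_seqs = max_concurrent_sequences(budget, num_kv_heads=8, **dims)
print(mha_seqs, gqa_seqs)  # 20 80
```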
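Faster decoding. Each decode step reads less data from memory. For memory-bandwidth-bound workloads (which describes most autoregressive decoding), less data to read means faster generation. The KV-cache share of that traffic shrinks in proportion to the reduction in KV heads; the end-to-end speedup is smaller, because every step also streams the model weights, and it grows with batch size and context length as the cache becomes a larger share of the bytes read.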
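To see why the end-to-end gain falls short of the 4x head reduction, count the bytes one decode step reads. A rough sketch with hypothetical numbers: a 7B-parameter FP16 model (32 layers, head dimension 128) serving a batch of 16 sequences at 4096 tokens, with the step treated as purely bandwidth-bound. It ignores the slightly smaller K and V projection weights a GQA model has, activations, and kernel overheads.

```python
def bytes_read_per_decode_step(weight_bytes, batch, context_len, num_layers,
                               num_kv_heads, head_dim, bytes_per_element=2):
    # Each step streams all weights once, plus every cached K and V
    # for every sequence in the batch.
    kv = batch * context_len * num_layers * num_kv_heads * 2 * head_dim * bytes_per_element
    return weight_bytes + kv


WEIGHTS = 7e9 * 2  # hypothetical 7B parameters in FP16
shape = dict(batch=16, context_len=4096, num_layers=32, head_dim=128)
mha = bytes_read_per_decode_step(WEIGHTS, num_kv_heads=32, **shape)
gqa = bytes_read_per_decode_step(WEIGHTS, num_kv_heads=8, **shape)
print(f"estimated speedup if bandwidth-bound: {mha / gqa:.2f}x")  # 2.14x
```

At batch 1 with a short context the weights dominate and the ratio approaches 1x; with large batches and long contexts it approaches the full 4x.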
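Longer contexts become practical. At 128K tokens, a model with 80 layers and 64 heads of dimension 128 (the shape of a 70B Llama model) would need about 320 GiB of FP16 KV cache per sequence with full MHA, more than a single GPU holds. With 8 KV heads (GQA-8), that drops to about 40 GiB, which keeps long-context inference manageable.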
Works with everything else. MQA and GQA are architectural choices made during model design (or applied through uptraining), and they're compatible with the other major inference optimizations: FlashAttention, speculative decoding, continuous batching, quantization, and so on. The benefits stack.
The Design Space
It's worth noting where MQA/GQA sit in the broader design space of "how do we make the KV cache smaller":
- MQA/GQA reduce the KV cache by sharing heads at the architectural level. Requires the model to be trained (or uptrained) with the configuration.
- KV cache quantization reduces the cache by storing values in lower precision (FP8 or INT4 instead of FP16). Can be applied post-training.
- KV cache eviction (H2O, StreamingLLM) reduces the cache by dropping old or unimportant tokens. Applied at serving time.
- Multi-head Latent Attention (MLA) from DeepSeek compresses the KV cache into a low-dimensional latent vector. Requires architectural changes during training.
These approaches are complementary. A model using GQA can also have its KV cache quantized to FP8 and use eviction policies for very long contexts. The reductions multiply together.
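The multiplication is plain arithmetic on bytes per cached token. A sketch with assumed numbers: a 64-head model with 80 layers and head dimension 128, going from MHA in FP16 to GQA-8 with an FP8 cache, then a hypothetical eviction policy that keeps 8192 of 32768 tokens.

```python
def kv_bytes_per_token(num_layers, num_kv_heads, head_dim, bytes_per_element):
    # K and V (factor of 2) for every layer and KV head, for one token.
    return num_layers * num_kv_heads * 2 * head_dim * bytes_per_element


baseline = kv_bytes_per_token(80, 64, 128, bytes_per_element=2)  # MHA, FP16
gqa_fp8 = kv_bytes_per_token(80, 8, 128, bytes_per_element=1)    # GQA-8, FP8
print(baseline / gqa_fp8)  # 16.0  (8x from heads times 2x from precision)

seq_len = 32768
kept = 8192  # hypothetical eviction budget: keep 1/4 of the tokens
print(baseline * seq_len / (gqa_fp8 * kept))  # 64.0
```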
Why This Matters on ASICs
MQA and GQA were designed to reduce memory bandwidth pressure on GPUs, where reading the KV cache from HBM is the primary bottleneck during decoding. On inference-optimized ASICs, the memory architecture is fundamentally different, with much higher bandwidth relative to compute and memory layouts designed specifically for the access patterns that attention uses.
General Compute is the only neocloud running entirely on inference-optimized ASICs rather than NVIDIA GPUs. The KV cache efficiency improvements from GQA still help on our hardware (smaller caches are always better), but the baseline memory bandwidth is so much higher that the gap between GQA and full MHA is smaller than it would be on GPUs. Our infrastructure is fast with either configuration, while GPU-based providers rely heavily on GQA to make decoding workable.
Sign up at generalcompute.com and get $5 in free credit to try it out.
Papers and References
- Fast Transformer Decoding: One Write-Head is All You Need (Shazeer, 2019)
- GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints (Ainslie et al., EMNLP 2023)