
Ring Attention: Scaling Context to Millions of Tokens

General Compute

Attention scales quadratically with sequence length in FLOPs and linearly in memory once you have a memory-efficient kernel like FlashAttention. The memory part is the harder constraint in practice. A single H100 has 80GB of HBM, and by the time you load model weights, optimizer states (during training), and the activations you need for backward, you run out of room for the KV cache long before you hit interesting context lengths. Doubling the context doubles the KV cache. At some point the key and value tensors for a single sequence just do not fit on one device.

Ring Attention, introduced by Hao Liu, Matei Zaharia, and Pieter Abbeel at Berkeley in 2023, solves this by spreading both the queries and the KV across a ring of devices, then rotating KV blocks around the ring while each device computes partial attention on whatever block it currently holds. The memory per device stays bounded, and the maximum context length scales linearly with the number of devices in the ring. With enough hardware, you get context windows of millions of tokens without materializing the full attention matrix anywhere.

This is one of the techniques that made million-token context windows practically possible. It is also a nice example of a distributed algorithm where the communication pattern matters as much as the arithmetic.

The Memory Wall for Long Context

Before getting into Ring Attention itself, it helps to be precise about what constrains context length on a single device.

For a transformer with hidden dimension d, sequence length N, and batch size B, the activations from a single attention layer are O(B * N * d) for the input and output tensors, plus whatever intermediate buffers the attention kernel needs. FlashAttention reduced the intermediate requirement from O(N^2) to O(N), which was the original breakthrough. But you still need to store Q, K, and V themselves, and you still need to hold the running output. If N is 1M tokens and d is 8192 with FP16, then each of Q, K, V is 1M * 8192 * 2 bytes = 16GB per layer, per batch item.
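As a sanity check, the arithmetic above can be scripted directly (plain Python, using the numbers from the text):

```python
# Back-of-envelope sizes for the example in the text: N = 1M tokens,
# d = 8192, FP16 (2 bytes per element).
def tensor_gb(n_tokens, hidden, bytes_per_elem=2):
    """Size in GB (1e9 bytes) of one [N, d] activation tensor."""
    return n_tokens * hidden * bytes_per_elem / 1e9

per_tensor = tensor_gb(1_000_000, 8192)
print(f"one of Q, K, V: {per_tensor:.1f} GB")       # ~16.4 GB each
print(f"Q + K + V:      {3 * per_tensor:.1f} GB")   # ~49 GB per layer, per batch item
```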

You cannot fit that on one GPU. Not even close, once you add the model weights.

The standard response is tensor parallelism and pipeline parallelism, which shard the model along the hidden dimension or across layers. Neither of these helps with the sequence dimension. If your problem is that a single sequence is too long for one device, splitting the model across devices does not buy you anything on the KV cache side.

Sequence parallelism is the obvious answer: split the sequence N across devices. But attention is not a pointwise operation. Every query token needs to see every key token, which means if you shard along N, each device needs access to the full K and V at some point during the computation. Naively, that means all-to-all communication, which is expensive.

Ring Attention is the clean way to structure that communication.

The Ring Attention Algorithm

Assume you have P devices arranged in a logical ring: device 0 talks to device 1, device 1 talks to device 2, and so on, with device P-1 wrapping back to device 0. Split the sequence of length N into P equal blocks. Device i holds block i of the queries Q_i, and initially also holds block i of the keys and values, K_i and V_i.

The algorithm proceeds in P rounds. On round t, each device i:

  1. Computes partial attention using its local Q_i against whatever K, V block it currently holds.
  2. Accumulates the partial output using an online softmax, just like FlashAttention does internally.
  3. Sends its current K, V block to the next device in the ring, and receives a new K, V block from the previous device.

After P rounds, each device has computed attention against every K, V block in the sequence, and the accumulated output on device i is the correct attention output for query block Q_i.
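A minimal single-host simulation of these rounds, in NumPy and without the causal mask, makes the bookkeeping concrete. This is one reasonable way to write it, not the paper's reference implementation: each "device" is just an index, and the rotation of step 3 becomes a permutation of which K, V block each device holds.

```python
import numpy as np

def ring_attention(q, k, v, P):
    """Single-host simulation of Ring Attention (no causal mask, for clarity).

    The sequence is split into P blocks. "Device" i keeps query block i while
    K/V blocks rotate around the ring; partial attention is accumulated with
    an online softmax, so the result is exact, not an approximation.
    """
    N, d = q.shape
    B = N // P
    Qs, Ks, Vs = (x.reshape(P, B, d) for x in (q, k, v))

    out = np.zeros((P, B, d))        # unnormalized accumulated output
    m = np.full((P, B), -np.inf)     # running row max of the scores
    l = np.zeros((P, B))             # running sum of exponentials

    held = list(range(P))            # held[i]: K/V block device i currently has
    for _ in range(P):               # P rounds
        for i in range(P):           # each device computes on its current block
            j = held[i]
            s = Qs[i] @ Ks[j].T / np.sqrt(d)
            m_new = np.maximum(m[i], s.max(axis=1))
            scale = np.exp(m[i] - m_new)              # rescale old accumulator
            p = np.exp(s - m_new[:, None])
            out[i] = out[i] * scale[:, None] + p @ Vs[j]
            l[i] = l[i] * scale + p.sum(axis=1)
            m[i] = m_new
        # step 3: send current block to the next device, receive from the previous
        held = [held[(i - 1) % P] for i in range(P)]

    return (out / l[:, :, None]).reshape(N, d)
```

Running this against a straightforward softmax-attention reference on random inputs matches to floating-point precision, which is the whole point: the ring changes the schedule and the memory footprint, not the math.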

The trick that makes this fast is that the send and receive on step 3 happen concurrently with the compute on step 1. Modern GPUs have dedicated copy engines and NVLink or InfiniBand interconnects that can move data independently of the SMs doing math. If the compute for one block of attention takes roughly as long as transferring one K, V block to the next device, you get the communication essentially for free. The total runtime is dominated by compute, and the communication hides behind it.
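Whether the transfer actually hides depends on block size and hardware. A rough model, with illustrative (assumed) H100-class numbers rather than measured ones:

```python
# Does the K/V transfer hide behind the attention compute on one block?
# Assumed, illustrative numbers: ~1e15 FLOP/s of dense BF16 compute and
# ~450e9 B/s of interconnect bandwidth; real sustained figures are lower.
def overlap_ratio(block_len, d, flops_per_s=1e15, bw_bytes_per_s=450e9):
    """Compute time over transfer time for one ring round; >= 1 means the
    communication is fully hidden behind the math."""
    flops = 4 * block_len * block_len * d      # QK^T and PV, 2*B*B*d FLOPs each
    comm_bytes = 2 * block_len * d * 2         # K and V blocks, FP16
    return (flops / flops_per_s) / (comm_bytes / bw_bytes_per_s)

# The hidden dimension cancels: the ratio is block_len * bandwidth / FLOPs.
for b in (1024, 4096, 16384):
    print(b, round(overlap_ratio(b, d=8192), 2))
```

Under these assumptions the crossover sits in the low thousands of tokens per block, which is one way to see why the blocks need to be large enough for the overlap to work.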

The memory on any single device is bounded by two K, V blocks (the current one and the one being transferred) plus one Q block plus the accumulated output. That is O(N/P) total, which is exactly what you want.

Why It Works: Blockwise Attention Plus the Ring

Ring Attention is really the combination of two ideas:

Blockwise attention with online softmax. This is the part that lets you compute attention incrementally over chunks of K and V without ever materializing the full N by N score matrix. It is the same math that FlashAttention uses internally. You maintain running statistics (the max and the sum of exponentials) and correct the accumulated output as new blocks come in. The result is numerically identical to standard attention, not an approximation.

Ring topology for communication. The ring is the key to making the communication cost scale well. If you did the same computation with a broadcast or all-to-all, each device would need to receive K, V blocks from every other device in a short burst, which saturates the network. In a ring, each device only communicates with two neighbors at a time, and the total bandwidth used per step is constant regardless of how many devices you have. The time per round is a single K, V block transfer, and there are P rounds, so the total communication time is O(N), matching the compute work per device.

This is not just a convenient structure. It is the reason Ring Attention scales. If the communication cost grew with P, adding more devices to extend context would eventually stop helping.

Ring FlashAttention

The natural next step is to combine Ring Attention with FlashAttention. On each device, instead of using a standard attention kernel to compute the partial attention against the current K, V block, you use FlashAttention. This gives you the best of both: the intra-device computation is memory-efficient and tiled to fit in SRAM, and the inter-device computation is load-balanced across the ring.

The implementation needs a few tweaks. The FlashAttention kernel normally does its own online softmax internally and emits a final output plus a log-sum-exp statistic. When you are composing it across ring rounds, each round produces a partial output and its own log-sum-exp, and you combine them across rounds using the same online softmax correction that FlashAttention uses internally. So you end up with an online softmax nested inside an online softmax, which sounds terrible but is actually just careful bookkeeping.
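That nesting can be sketched with a toy stand-in for the kernel, assuming (as the text describes) that each call returns a normalized partial output plus its log-sum-exp:

```python
import numpy as np

def flash_block(q, k, v):
    """Toy stand-in for a FlashAttention call on one K/V block: returns the
    normalized partial output and the log-sum-exp of the scores per query."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    m = s.max(axis=1, keepdims=True)
    p = np.exp(s - m)
    l = p.sum(axis=1)
    return (p @ v) / l[:, None], m[:, 0] + np.log(l)

def merge(o1, lse1, o2, lse2):
    """The outer online softmax: combine two partial results computed over
    disjoint K/V ranges into the exact result over their union."""
    m = np.maximum(lse1, lse2)
    w1, w2 = np.exp(lse1 - m), np.exp(lse2 - m)
    o = (o1 * w1[:, None] + o2 * w2[:, None]) / (w1 + w2)[:, None]
    return o, m + np.log(w1 + w2)
```

Folding `merge` over the per-round results reconstructs full attention exactly, which is the "careful bookkeeping" the composition amounts to.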

Most production implementations of long-context serving use some variant of Ring FlashAttention. The ring handles the inter-device dimension, FlashAttention handles the intra-device dimension.

Striped Attention: The Load Balancing Problem

There is a subtle issue with the simple ring algorithm, which is that causal masking creates a load imbalance.

In a causal transformer (which is what decoder-only LLMs use), each query token can only attend to key tokens at positions less than or equal to its own. If you split the sequence into contiguous blocks, block i contains queries at positions [i*N/P, (i+1)*N/P). When block i computes attention against block j:

  • If j < i, the full block is below the diagonal, so every query-key pair is valid. This is a full compute load.
  • If j > i, the full block is above the diagonal, so no query-key pair is valid. The device does essentially no work.
  • If j == i, you are on the diagonal and half the pairs are valid.

The result is that each device does only about half the total work, but in a bursty pattern: its first rounds are all full blocks and its last rounds are all empty ones. Because the ring advances in lockstep, every round takes as long as its busiest device, so the overall throughput is cut roughly in half.

Striped Attention, a follow-up by William Brandon and others in 2023, fixes this by changing how the sequence is partitioned. Instead of giving each device a contiguous chunk of tokens, you interleave tokens so that each device gets every P-th token. Device i holds tokens at positions i, P+i, 2P+i, and so on. Now when device i computes attention against block j, the set of valid pairs is roughly the same regardless of which block j is, because you are always comparing interleaved slices that span the whole sequence.

The compute load per round becomes uniform, and you recover the full theoretical throughput.

In practice, you often want to apply striping at the block level rather than the individual-token level, because individual-token striping messes up things like rotary position embeddings and the memory access patterns of the underlying kernels. Block-level striping (device i gets blocks i, P+i, 2P+i, and so on) gets most of the load balancing benefit without the complications.
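The imbalance and its fix are easy to see by counting unmasked (query, key) pairs per device per round under both partitions, using toy sizes and the ring schedule described above:

```python
import numpy as np

def work_per_round(positions, P):
    """Count unmasked (query, key) pairs per device per round of the ring.
    positions[i] holds the absolute token positions assigned to device i;
    on round t, device i computes against the K/V block of device (i - t) % P."""
    work = np.zeros((P, P), dtype=int)
    for t in range(P):
        for i in range(P):
            j = (i - t) % P
            q_pos = positions[i][:, None]
            k_pos = positions[j][None, :]
            work[t, i] = int((k_pos <= q_pos).sum())  # causal mask
    return work

N, P = 16, 4
contig = [np.arange(i * N // P, (i + 1) * N // P) for i in range(P)]
striped = [np.arange(i, N, P) for i in range(P)]
print(work_per_round(contig, P))   # rows mix full blocks (16) with idle zeros
print(work_per_round(striped, P))  # every entry close to the mean
```

Both partitions do the same total work (the causal mask admits N(N+1)/2 pairs either way); only the striped one spreads it evenly across rounds.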

Inference Versus Training

Ring Attention was originally pitched as a training technique, and the original paper's benchmarks focused on training throughput. But it is at least as important for inference, possibly more so.

During training, you have large batches and can usually absorb long sequences by shrinking the batch size or using gradient accumulation. The activation memory is the bottleneck, and Ring Attention addresses that.

During inference, the bottleneck is different. The KV cache grows linearly with context and has to be kept in memory across every decode step. For a single long-context request (one user asking a question about a million-token document), you often have no batch dimension to shrink. The KV cache alone for a 1M token context on a 70B model is about 340GB in FP16. You need to shard it across devices, and when you go to compute attention for the next decoded token, the query on one device needs to see the KV on all the other devices.
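The rough figure above can be reproduced with an assumed Llama-70B-style attention configuration (80 layers, 8 KV heads under grouped-query attention, head dimension 128); the exact number depends on the model's actual config, which the text does not pin down:

```python
# KV-cache size under an assumed Llama-70B-style config:
# 80 layers, 8 KV heads (GQA), head dim 128, FP16.
def kv_cache_gb(n_tokens, n_layers=80, n_kv_heads=8, head_dim=128, bytes_per_elem=2):
    per_token = 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem  # 2x: K and V
    return per_token * n_tokens / 1e9

print(f"{kv_cache_gb(1_000_000):.0f} GB")  # in the ballpark of the ~340GB figure
```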

That is exactly the setup Ring Attention was designed for, just applied to decoding instead of prefill. The query tensor for a single decode step is tiny (one token worth of Q), so you can afford to broadcast or replicate it. The KV is the expensive part, and the ring handles the sharded compute naturally.

Some production serving systems use a hybrid: Ring Attention for prefill (where the query is long and the ring is well-balanced), and a different strategy for decode (where the query is one token and you can be cleverer about which devices need to participate at all). But the underlying primitive is the same.

What This Enables

Million-token context windows, like what Gemini 1.5 Pro shipped with, are not possible without something like Ring Attention under the hood. Google has not published the exact architecture, but the publicly available description of their approach makes it clear they are sharding sequences across devices with overlapped communication. Ring Attention is the reference algorithm for this class of technique, and the techniques described in the paper predate Gemini's long-context release.

The broader point is that context length is now primarily a systems problem, not an algorithmic one. The attention math itself has been well understood for years. What changed is how efficiently you can execute that math across a large fleet of accelerators, with the communication topology being as important as the arithmetic. Ring Attention is one of the cleanest examples of how to get the topology right.

Why Purpose-Built Hardware Matters Here

Ring Attention squeezes the most out of the hardware topology you have. But the underlying constraint is still memory bandwidth and interconnect bandwidth on chips that were not designed for this workload. NVLink and InfiniBand are fast, but they are general-purpose, and the choreography required to keep a ring saturated is fragile.

General Compute is the only neocloud built entirely on inference-optimized ASICs. The interconnect fabric between chips is designed specifically for the access patterns that long-context inference needs, including the streaming KV transfers that Ring Attention relies on. Combined with much higher on-chip memory per accelerator, this reduces the pressure on sequence parallelism in the first place: you need fewer devices in your ring, and the rounds are faster, so long-context serving is faster and cheaper end to end.

If you want to try running long-context inference on hardware that was purpose-built for it, sign up at generalcompute.com and get $5 in free credit to benchmark it against your current setup.
