
Token Merging and Token Pruning for Faster Transformers

General Compute

Attention is quadratic in sequence length. For a transformer block with N tokens, the attention matrix is N by N, and the compute and memory costs scale accordingly. Everything else in the block (the feedforward network, the projections, the layernorms) scales linearly with N. So if you want a transformer to run faster and you cannot change the hardware or the model weights, the most direct lever is to reduce N.

Token merging and token pruning do exactly that. They shrink the sequence in the middle of the network, after the model has already formed some view of which tokens matter, and let the remaining layers do less work. The techniques started in vision transformers, where adjacent patches often carry similar information, and they have since been adapted to language models, speech models, and multimodal systems. The accuracy cost is real but often small, and the speedup is immediate because it attacks the part of the cost function that grows the fastest.

This post walks through ToMe (the canonical token merging paper), compares it to the family of token pruning methods, and covers the issues that come up when you try to apply these ideas to LLM inference rather than vision classification.

The Sequence Length Problem

In a standard transformer layer, the attention computation (the score matrix and the weighted sum over values) is O(N^2 * d), while the projections and the feedforward pass are O(N * d^2), where d is the hidden dimension. For short sequences, the linear terms dominate because d is usually much larger than N. For long sequences, attention dominates because N^2 overtakes N * d. The crossover point is roughly when N is on the order of d.
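To make the crossover concrete, here is a rough per-layer count of multiply-accumulates (MACs). It is a sketch under stated assumptions: four d-by-d projections (Q, K, V, output), a feedforward network with a 4x expansion, full (non-causal) attention, and no kernel-level effects. `layer_macs` and `crossover_n` are hypothetical helper names.

```python
def layer_macs(n: int, d: int, ffn_mult: int = 4) -> tuple[int, int]:
    """Rough MAC count for one transformer layer over n tokens of width d.

    Returns (attention_macs, linear_macs). Assumes full (non-causal)
    attention and a feedforward network with hidden width ffn_mult * d.
    """
    attention = 2 * n * n * d          # Q @ K^T scores, then scores @ V
    projections = 4 * n * d * d        # Q, K, V and output projections
    ffn = 2 * n * d * (ffn_mult * d)   # up- and down-projection
    return attention, projections + ffn


def crossover_n(d: int, ffn_mult: int = 4) -> int:
    """Sequence length where attention MACs equal linear MACs.

    Solves 2 n^2 d = (4 + 2 * ffn_mult) n d^2 for n.
    """
    return (2 + ffn_mult) * d


if __name__ == "__main__":
    d = 4096
    for n in (1024, 8192, 131072):
        attn, linear = layer_macs(n, d)
        print(f"N={n:>6}: attention share {attn / (attn + linear):.2f}")
    print("crossover N:", crossover_n(d))
```

Under these assumptions the crossover sits at N = 6d, which is "on the order of d". With d = 4096, attention is about 4 percent of a layer's MACs at N = 1024 and about 84 percent at N = 131072, and because the attention term is quadratic, halving N cuts it to a quarter.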

Modern models pull in both directions. They have grown wider (d increases), which pushes the crossover toward longer sequences, but context windows have grown much faster (N increases from 2K to 128K to 1M). In vision, a 224x224 image with a 16x16 patch size gives 196 tokens, and a 336x336 image with 14x14 patches (as in CLIP ViT-L/14 at 336 pixels) gives 576. Video transformers easily reach tens of thousands of tokens once you add a temporal dimension.
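The patch counts follow from dividing each image side by the patch size. A minimal sketch (the helper `num_patch_tokens` is illustrative) also shows why 576 tokens corresponds to 14x14 patches at 336 pixels, while 16x16 patches at that resolution give 441:

```python
def num_patch_tokens(height: int, width: int, patch: int, cls_token: bool = False) -> int:
    """Number of tokens a ViT produces for an image cut into square patches.

    Assumes both sides are divisible by the patch size (no padding).
    """
    if height % patch or width % patch:
        raise ValueError("image size must be a multiple of the patch size")
    return (height // patch) * (width // patch) + (1 if cls_token else 0)


print(num_patch_tokens(224, 224, 16))  # 196
print(num_patch_tokens(336, 336, 14))  # 576
print(num_patch_tokens(336, 336, 16))  # 441
# Video multiplies the count by the number of temporal slices (16 is hypothetical).
print(16 * num_patch_tokens(224, 224, 16))  # 3136
```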

Once N is large enough that attention dominates, the payoff for reducing N is direct. Cutting the sequence in half roughly quarters the attention cost of the remaining layers. Most of the reductions in the literature target 30 to 50 percent fewer tokens by the middle of the network, with the intent of keeping enough information that the final prediction is still correct.

What ToMe Actually Does

ToMe (Token Merging) was introduced by Bolya et al. in 2022 for vision transformers. The core idea is that many tokens in a ViT (often neighboring patches) carry redundant information, so merging similar tokens pairwise inside each transformer block shrinks the sequence without training a new model or adding new parameters.

The mechanism has three pieces.

The first is a similarity score. Each token already has a key vector that the model computed during attention. ToMe uses the cosine similarity between keys (averaged over attention heads) as the matching signal. This is nearly free: the keys are already there, and cosine similarity is just a normalization followed by a dot product.

The second is a bipartite matching step. ToMe splits the tokens alternately into two sets (A and B), finds the best match for each token in A within set B using the similarity scores, and picks the top r pairs with the highest similarity. The r parameter is the merging budget for that layer, and it is the main knob that trades speed for accuracy. Setting r to 8 means merge 8 pairs per block, removing 8 tokens per block.
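The similarity and matching steps can be sketched in plain Python. This is an illustrative reimplementation, not the ToMe reference code: it treats each token's key as one vector (ToMe averages over heads first), does not protect the class token, and uses the hypothetical name `bipartite_match`.

```python
import math


def cosine(u, v):
    """Cosine similarity: normalize both vectors, then take the dot product."""
    nu = math.sqrt(sum(a * a for a in u)) or 1.0  # guard against zero vectors
    nv = math.sqrt(sum(b * b for b in v)) or 1.0
    return sum(a * b for a, b in zip(u, v)) / (nu * nv)


def bipartite_match(keys, r):
    """Pick up to r (a, b) merge pairs from per-token key vectors.

    Tokens alternate into set A (even positions) and set B (odd positions).
    Each A token finds its most similar B token; the r A tokens whose best
    match is most similar are selected, each to be merged into its match.
    """
    a_idx = list(range(0, len(keys), 2))
    b_idx = list(range(1, len(keys), 2))
    if not b_idx:
        return []
    r = min(r, len(a_idx))  # at most every A token can be merged away
    best = []
    for i in a_idx:
        score, j = max((cosine(keys[i], keys[j]), j) for j in b_idx)
        best.append((score, i, j))
    best.sort(reverse=True)  # most similar pairs first
    return [(i, j) for _, i, j in best[:r]]


keys = [[1.0, 0.0], [1.0, 0.1], [0.0, 1.0], [0.2, 1.0], [1.0, 1.0], [-1.0, 1.0]]
print(bipartite_match(keys, r=2))  # [(0, 1), (2, 3)]
```

Several A tokens may pick the same B token; the merge step has to handle that, which is why it keeps a running count per token.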

The third is the merge itself. Each selected A token is folded into its B match by a weighted average, where the weights are running counts of how many original tokens each side already represents. ToMe keeps those counts as a size vector s and reuses it in later attention passes as "proportional attention": log s is added to the attention logits, so a merged token receives the same attention that s unmerged copies of its key would receive together. This fix matters more than it sounds like it should, because without it a token that stands for many patches counts as a single key, gets less attention than the content it represents, and accuracy drops.
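A companion sketch of the merge and of proportional attention, again a hypothetical plain-Python illustration rather than ToMe's implementation. `merge_pairs` applies the size-weighted average and `proportional_attention` adds log s to the logits for a single query:

```python
import math


def merge_pairs(x, size, pairs):
    """Size-weighted merge (hypothetical sketch of ToMe's merge step).

    x: token vectors; size: how many original tokens each vector stands for;
    pairs: (src, dst) pairs from disjoint sets, as bipartite matching gives.
    Returns the shortened token list and its size vector.
    """
    x = [list(v) for v in x]
    size = list(size)
    removed = set()
    for src, dst in pairs:
        total = size[src] + size[dst]
        # Running weighted average, so several sources can merge into one dst.
        x[dst] = [(size[src] * a + size[dst] * b) / total
                  for a, b in zip(x[src], x[dst])]
        size[dst] = total
        removed.add(src)
    keep = [i for i in range(len(x)) if i not in removed]
    return [x[i] for i in keep], [size[i] for i in keep]


def proportional_attention(q, keys, size):
    """Softmax over q.k / sqrt(d) + log(s) for a single query vector q."""
    scale = 1.0 / math.sqrt(len(q))
    logits = [scale * sum(a * b for a, b in zip(q, k)) + math.log(s)
              for k, s in zip(keys, size)]
    peak = max(logits)  # subtract the max for numerical stability
    weights = [math.exp(v - peak) for v in logits]
    total = sum(weights)
    return [w / total for w in weights]


tokens, sizes = merge_pairs([[0.0, 0.0], [2.0, 2.0], [5.0, 5.0]], [1, 3, 1], [(0, 1)])
print(tokens, sizes)  # [[1.5, 1.5], [5.0, 5.0]] [4, 1]
```

With the log s term, a merged token of size 3 gets exactly the combined weight that three identical unmerged keys would get, which is the property proportional attention is there to preserve.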

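ToMe runs inside each block, between the attention and MLP branches, so the MLP and every later block see the shorter sequence. A ViT-L/16 at 224x224 starts with 197 tokens (196 patches plus a class token); with r=8 in each of its 24 blocks, only a handful of tokens remain after the last block, and the average block processes roughly half the original sequence. On ImageNet classification, ToMe reports around 2x throughput with about 0.4 percent top-1 accuracy drop on ViT-L, and the drop shrinks further when the model is fine-tuned with ToMe active during training.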
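The token schedule for that configuration is simple arithmetic. The sketch below (hypothetical `tome_schedule`) assumes 197 input tokens, one protected class token, and a cap of (n - protected) // 2 merges per block, because set A cannot supply more pairs than that. For simplicity it charges each block by the tokens entering it, ignoring that the MLP inside a block already runs on the reduced sequence.

```python
def tome_schedule(n0: int, r: int, blocks: int, protected: int = 1) -> tuple[list[int], int]:
    """Token counts entering each block when up to r tokens merge per block.

    At most (n - protected) // 2 tokens can merge in a block, since
    `protected` tokens (e.g. the class token) never merge.
    Returns (tokens entering each block, tokens after the last block).
    """
    counts = []
    n = n0
    for _ in range(blocks):
        counts.append(n)
        n -= min(r, (n - protected) // 2)
    return counts, n


counts, final = tome_schedule(197, 8, 24)
attn_ratio = sum(c * c for c in counts) / (len(counts) * 197 ** 2)
linear_ratio = sum(counts) / (len(counts) * 197)
print(final, round(attn_ratio, 2), round(linear_ratio, 2))  # 7 0.36 0.53
```

Under these assumptions, 7 of the 197 tokens survive the last block, the summed attention cost falls to about 36 percent of the unmerged model, and the cost of the token-linear parts to about 53 percent.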

The paper also shows that ToMe composes with other acceleration methods. It does not change the model weights, does not change the training loss, and does not require the architecture to expose any special hooks. You insert it, pick r per block, and you get a faster model.

Token Pruning: The Other Half of the Family

Token pruning takes a more aggressive position. Instead of merging tokens it thinks are similar, it drops tokens it thinks are unimportant. The distinction matters. Merging preserves information from every input token (in aggregate), while pruning discards information entirely. This lets pruning reach larger speedups at the cost of more careful selection.

The canonical vision version is DynamicViT (Rao et al., 2021), which inserts lightweight prediction modules at a few intermediate stages of the network to score each token's importance. At inference, the lowest-scoring tokens are dropped from subsequent layers so that a target fraction is kept at each stage. DynamicViT has to be trained, with a distillation loss from the unpruned model, to teach the predictors which tokens are safe to drop. That makes it less plug-and-play than ToMe but also gives it more control over what gets kept.

Other vision pruning methods use attention scores directly. EViT (Liang et al.) keeps the top-k tokens by attention weight from the CLS token and fuses the rest into a single pooled token. A-ViT (Yin et al.) uses the model's own halting scores. The variations mostly differ in how they score tokens and whether they drop or fuse the unimportant ones.
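For the score-based family, an EViT-style step can be sketched as: keep the most attended patch tokens by CLS attention and fuse the rest into one token weighted by that attention. This is a hypothetical illustration (`keep_topk_and_fuse` is not from any library). It assumes CLS attention is already averaged over heads and the CLS token is handled separately, and it normalizes the fusion weights so the fused token stays on the scale of its inputs, which is an assumption of the sketch.

```python
import math


def keep_topk_and_fuse(tokens, cls_attn, keep_ratio):
    """EViT-style token reorganization (hypothetical sketch).

    tokens: patch-token vectors (CLS excluded).
    cls_attn: attention from CLS to each patch token, averaged over heads.
    Keeps the ceil(keep_ratio * n) most attended tokens in original order
    and appends one token fusing the rest, weighted by normalized CLS attention.
    """
    n = len(tokens)
    k = max(1, math.ceil(keep_ratio * n))
    ranked = sorted(range(n), key=lambda i: cls_attn[i], reverse=True)
    keep, drop = sorted(ranked[:k]), ranked[k:]
    out = [list(tokens[i]) for i in keep]
    if drop:
        weight = sum(cls_attn[i] for i in drop) or 1.0  # avoid dividing by zero
        dim = len(tokens[0])
        out.append([sum(cls_attn[i] * tokens[i][c] for i in drop) / weight
                    for c in range(dim)])
    return out


patches = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [4.0, 0.0]]
attn = [0.4, 0.1, 0.3, 0.2]
print(keep_topk_and_fuse(patches, attn, keep_ratio=0.5))
# kept: patches 0 and 2; fused token: (0.2 * [4, 0] + 0.1 * [0, 1]) / 0.3
```

Dropping the final `append` turns this into plain top-k pruning, which is the whole difference between the "drop" and "fuse" variants.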

A useful way to think about the family: merging assumes redundancy, pruning assumes irrelevance. In a densely packed image, most adjacent patches are redundant with their neighbors, so merging works well. In a sparse input (say, a medical image where most of the frame is empty tissue and only a small region matters), pruning works better because the irrelevant tokens really are irrelevant.

Moving to Language Models

Applying token merging or pruning to language transformers is harder than applying it to vision. The reasons are structural, and they explain why LLM serving has adopted these ideas slowly compared to ViTs.

First, language tokens are less redundant than image patches. Two adjacent patches in an image are often nearly identical. Two adjacent tokens in a sentence rarely are. Merging "the" and "cat" into an average representation loses the distinction between them in a way that averaging two neighboring pixel patches does not.

Second, autoregressive generation is causal. A ViT processes the whole input once to make one prediction, so the pruned sequence only has to support that single output. In a decoder, every future token attends to every past token, so dropping token i changes the attention output of every token j where j > i, including tokens that have not been generated yet and whose needs cannot be known when the decision is made. The effect compounds across layers and generation steps, and the model's behavior at generation time drifts from what it saw during training.

Third, the KV cache changes the accounting. In vision, reducing the sequence reduces both the forward pass cost and the memory needed to hold activations. In LLM serving, the dominant cost at long context is the KV cache (its memory footprint, and the bandwidth needed to read it at every decode step), not the forward-pass compute. A method that shrinks the working sequence in each attention pass but does not remove entries from the KV cache saves compute, not memory.
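A quick size estimate shows why the cache dominates. The sketch assumes fp16 keys and values and a shape like Llama 2 7B (32 layers, 32 key/value heads, head dimension 128, no grouped-query attention); `kv_cache_bytes` is a hypothetical helper.

```python
def kv_cache_bytes(layers: int, kv_heads: int, head_dim: int, tokens: int,
                   bytes_per_elem: int = 2) -> int:
    """Bytes needed to cache keys and values for `tokens` positions of one sequence."""
    return 2 * layers * kv_heads * head_dim * tokens * bytes_per_elem  # 2 = K and V


GiB = 2 ** 30
full = kv_cache_bytes(32, 32, 128, 128 * 1024)
print(full / GiB)  # 64.0
pruned = kv_cache_bytes(32, 32, 128, 64 * 1024)  # cache keeps half the positions
print(pruned / GiB)  # 32.0
```

Only a method that actually evicts entries changes the `tokens` argument. One that merely skips tokens in the forward pass leaves the full 64 GiB in place.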

There are now several approaches that address these constraints.

LazyLLM (Apple, 2024) introduced dynamic token pruning specifically for LLM inference, aimed mainly at the prefill stage. Its observation is that computing the KV cache for every prompt token before producing the first output token is expensive, even though many prompt tokens matter little for predicting that token. LazyLLM ranks prompt tokens by the attention they receive from the last token at selected layers and progressively skips the less important ones in later layers. Pruned tokens are not discarded permanently: their hidden states go into an auxiliary cache, so if a later decoding step needs them, they can be revived without recomputing the earlier layers. This shortens time-to-first-token on long prompts while keeping every prompt token available for later steps.
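A single pruning step in that spirit can be sketched as follows. This is a hypothetical illustration, not Apple's code: `lazy_prune` and its `keep_frac` knob are assumptions, attention scores are taken as given, and LazyLLM's actual layer schedule and thresholds are not reproduced.

```python
import math


def lazy_prune(active, attn_from_last, keep_frac, layer, hidden, aux_cache):
    """One progressive pruning step in the spirit of LazyLLM (hypothetical sketch).

    active: token positions still computed at this layer; the last entry is
        the token making the next-token prediction and is always kept.
    attn_from_last: position -> attention it receives from that last token
        (assumed already averaged over heads).
    hidden: position -> hidden state at this layer.
    aux_cache: position -> (layer, hidden state) for pruned tokens, so a later
        step could resume them from this layer instead of recomputing.
    """
    *candidates, last = active
    k = math.ceil(keep_frac * len(candidates))
    ranked = sorted(candidates, key=lambda t: attn_from_last[t], reverse=True)
    kept = set(ranked[:k])
    for t in candidates:
        if t not in kept:
            aux_cache[t] = (layer, hidden[t])
    return [t for t in candidates if t in kept] + [last]


aux = {}
hidden = {t: [float(t)] for t in range(5)}
attn = {0: 0.5, 1: 0.1, 2: 0.3, 3: 0.05, 4: 0.05}
active = lazy_prune([0, 1, 2, 3, 4], attn, 0.5, layer=8, hidden=hidden, aux_cache=aux)
print(active, sorted(aux))  # [0, 2, 4] [1, 3]
```

Calling it again at deeper layers with the returned `active` list gives the progressive shape, with fewer tokens computed the further into the network you go.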

PyramidInfer (2024) reduces the KV cache layer by layer, keeping more tokens in the earlier layers and fewer in the later ones. The resulting shape is a pyramid, hence the name. The motivation is that attention in later layers often concentrates on a small number of tokens, while earlier layers spread their attention more broadly. Keeping a larger KV cache in early layers and a progressively smaller one in later layers saves memory without much accuracy cost.
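The pyramid shape itself is easy to model. The sketch below only captures per-layer budgets: the linear schedule and the 1.0 and 0.2 endpoints are hypothetical, and it says nothing about which tokens fill each budget, which PyramidInfer decides from attention scores.

```python
def pyramid_budgets(n_tokens: int, layers: int,
                    first_frac: float = 1.0, last_frac: float = 0.2) -> list[int]:
    """Per-layer KV budgets shrinking linearly from first_frac to last_frac of the context."""
    if layers == 1:
        return [round(n_tokens * first_frac)]
    step = (last_frac - first_frac) / (layers - 1)
    return [round(n_tokens * (first_frac + step * layer)) for layer in range(layers)]


budgets = pyramid_budgets(1000, 5)
print(budgets)  # [1000, 800, 600, 400, 200]
print(sum(budgets) / (1000 * 5))  # 0.6
```

With these endpoints the total cache is 60 percent of the unpruned one, while the early layers that spread their attention widely keep their full context.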

H2O (Heavy-Hitter Oracle) and similar methods prune the KV cache directly based on accumulated attention weights. A token that has received a large share of attention across past decoding steps is a "heavy hitter" and is kept, along with a window of the most recent tokens; a token that has received little attention gets evicted. These methods are more aggressive and can reduce KV cache size by 50 percent or more on long-context workloads, though accuracy on needle-in-a-haystack tasks usually suffers if the pruning is too aggressive.
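The eviction rule can be sketched in a few lines. `h2o_keep` and `accumulate` are hypothetical names, the budgets are illustrative, and the per-step attention rows are assumed to be given; a real implementation does this per layer and per head on GPU tensors.

```python
def accumulate(acc, attn_row):
    """Add one decoding step's attention weights over all cached positions.

    attn_row may be longer than acc if new positions were appended since the
    last step; those start from zero.
    """
    acc = acc + [0.0] * (len(attn_row) - len(acc))
    return [a + w for a, w in zip(acc, attn_row)]


def h2o_keep(acc_attn, heavy_budget, recent):
    """Positions an H2O-style policy keeps (hypothetical sketch).

    The `recent` newest positions are always kept; among older positions,
    the `heavy_budget` with the largest accumulated attention stay.
    """
    n = len(acc_attn)
    recent_ids = set(range(max(0, n - recent), n))
    older = [i for i in range(n) if i not in recent_ids]
    heavy = sorted(older, key=lambda i: acc_attn[i], reverse=True)[:heavy_budget]
    return sorted(set(heavy) | recent_ids)


acc = accumulate([], [0.5, 0.05, 0.3, 0.05, 0.05, 0.05])
acc = accumulate(acc, [0.4, 0.05, 0.2, 0.1, 0.1, 0.15])
print(h2o_keep(acc, heavy_budget=2, recent=2))  # [0, 2, 4, 5]
```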

StreamingLLM (Xiao et al.) takes a different angle: keep the first few "attention sink" tokens plus a sliding window of recent tokens, and drop everything else. This is pruning as a fixed, position-based policy rather than one driven by attention scores. It keeps generation stable over effectively unbounded streams but loses information from the middle of the sequence, so it is not a drop-in replacement for full attention on retrieval tasks.
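The policy reduces to a rule over positions alone, which a hypothetical `streaming_keep` helper captures. The sketch ignores how positional encodings are assigned inside the rolling cache, and the sink and window sizes in the example are illustrative.

```python
def streaming_keep(n_tokens: int, sinks: int, window: int) -> list[int]:
    """Cache positions kept by a StreamingLLM-style policy: the first `sinks`
    positions (attention sinks) plus the most recent `window` positions."""
    if n_tokens <= sinks + window:
        return list(range(n_tokens))
    return list(range(sinks)) + list(range(n_tokens - window, n_tokens))


print(streaming_keep(10, sinks=2, window=3))  # [0, 1, 7, 8, 9]
```

Everything between the sinks and the window is gone for good, which is exactly why retrieval from the middle of a long document fails under this policy.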

Composition with Other Inference Techniques

Token reduction methods compose well with most inference optimizations, but the interactions are worth thinking about.

With paged attention and KV cache management, token pruning has to know whether it is removing tokens from the working set (cheap) or evicting them from the cache (requires recomputation if the tokens come back). Most modern implementations distinguish these cases explicitly.

With speculative decoding, token reduction in the verifier model has to stay conservative enough that the verifier's scoring of draft tokens does not drift from what the full model would produce. A too-aggressive pruner can cause acceptance rates to drop, which eats into the speculative speedup.

With quantization, the effects are roughly additive. A quantized model with token merging gets both the per-operation speedup of quantization and the sequence-length speedup of merging. The accuracy cost is also roughly additive, so the total drop is larger than either alone. In practice, people tune them together rather than stacking them independently.

With continuous batching, token reduction shrinks the per-request compute and memory footprint, which lets more requests fit in the same batch. This is usually the easiest way to see the benefit in a serving system: not a change in single-request latency, but a change in how many concurrent requests the engine can sustain.
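A back-of-the-envelope capacity estimate makes the effect visible. The sketch assumes the whole budget goes to KV cache, ignores fragmentation and other per-request overheads, and uses an illustrative 512 KiB per token (fp16 KV for a 32-layer model with 32 heads of dimension 128). `max_concurrent_requests` is a hypothetical helper.

```python
import math


def max_concurrent_requests(cache_budget_bytes, tokens_per_request,
                            kv_bytes_per_token, kept_fraction=1.0):
    """Requests whose KV caches fit in a fixed memory budget (rough sketch).

    kept_fraction < 1 models a method that keeps only that share of cache
    entries. Ignores fragmentation, activations and other overheads.
    """
    kept_tokens = math.ceil(tokens_per_request * kept_fraction)
    return cache_budget_bytes // (kept_tokens * kv_bytes_per_token)


GiB, KiB = 2 ** 30, 2 ** 10
per_token = 512 * KiB  # assumed fp16 KV for a 32-layer, 32-head, 128-dim model
print(max_concurrent_requests(40 * GiB, 8192, per_token))       # 10
print(max_concurrent_requests(40 * GiB, 8192, per_token, 0.5))  # 20
```

Halving what each request keeps doubles how many requests fit, even though no single request finishes any faster.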

When Token Reduction Is Worth It

Token merging and pruning are worth it when sequence length is the dominant cost. For short-context chat (a few hundred tokens in, a few hundred out), the feedforward and projection costs are already much larger than attention, and saving tokens does not save much overall. For long-context workloads (summarization, document QA, long agentic traces, vision-language models with many image tokens), attention dominates and the savings show up clearly.

They are also worth it when you can afford a small accuracy regression and cannot afford to retrain the model. ToMe famously works without retraining, and most LLM-side methods (LazyLLM, PyramidInfer) are training-free or nearly so. If you have the budget to fine-tune with the method active, accuracy recovers substantially, so a short fine-tuning or calibration run with the method enabled is often worth its cost.

They are not worth it when the workload is short-context and compute-bound, or when the accuracy budget is very tight. On a 1K-context chat model, token merging might save 10 percent latency at a 0.5 percent accuracy cost. On a 128K-context document QA model, the same method might save 40 percent latency at the same accuracy cost. The payoff scales with N.

The Practical Picture

Across both vision and language, the lesson is the same. Transformers do not need all their tokens to make their final prediction, and most of the time, the tokens they do need can be identified cheaply. Cosine similarity on keys works. Attention-score heuristics work. Small learned predictors work. The specific choice depends on the model and the workload, but the existence of the redundancy is now well established.

For a serving stack, the main question is where to apply the reduction: in the forward pass, in the KV cache, or both. ToMe and its descendants reduce the forward pass. H2O and StreamingLLM reduce the cache. PyramidInfer does both, at different rates per layer. Getting these choices right usually involves profiling the specific model and workload, because the right r per block or the right eviction threshold is not universal.

At General Compute, our infrastructure focuses on the primitives that sit underneath these techniques: fast attention kernels, efficient KV cache management, and schedulers that understand variable sequence lengths. Token merging and pruning are less commonly exposed as a user-level feature because they change model behavior in ways developers usually want to control themselves, but the underlying engine is built to support them cleanly when a workload needs them. If you are running long-context inference and want to experiment with this class of optimizations, the API and docs are the place to start.
