Attention kernels: from PagedAttention to FlashAttention, FlashInfer, and MLA

In Chapter 4 we made the KV cache paged. We chopped each sequence’s keys and values into fixed 16-token blocks, scattered those blocks anywhere in a global pool, and kept a per-request block table mapping logical positions to physical blocks. That bought us the concurrency that makes a replica economical. But it also quietly broke something. A textbook attention kernel wants its keys and values laid out contiguously: one flat [seq_len, num_heads, head_size] tensor it can stream linearly. Paging hands the kernel the opposite of that — a sequence whose KV lives in blocks that are physically discontiguous and, across a batch, of wildly different lengths. The layout is only half the design. The other half is a kernel that can attend over ragged, block-scattered sequences and still hit the memory-bandwidth roofline we established in Chapter 3.

This chapter is about that kernel, and about where it lives in vLLM. The punchline up front: there is no single attention kernel. There is a dispatch seam — an opaque custom op — and behind it a zoo of backends, each tuned for a hardware generation, a precision, or a model architecture. The most consequential of those architectures is Multi-head Latent Attention, which doesn’t just need a different kernel. It changes the KV cache you sized in Chapter 4 into a single compressed latent vector per token, and it’s how the flagship 2026 open models actually run.

Before diving in, it helps to fix the vocabulary the rest of the chapter leans on. Attention is the operation where each query token computes a weighted sum over the keys and values of every token it can see; the weights come from a softmax over query-key dot products. Prefill is the phase that processes all the prompt tokens at once (many queries), and decode is the phase that generates one token at a time (one query attending over a long history). A kernel is a single GPU program launch; the cost we keep returning to is whether that launch is bandwidth-bound (limited by how fast it can read the KV cache out of memory) or launch-bound (limited by the fixed CPU and scheduling overhead of starting the kernel at all). The whole chapter is the story of how vLLM keeps the attention kernel bandwidth-bound while feeding it the awkward, paged, ragged input that real serving produces.

The dispatch seam

Start with the indirection, because it explains everything downstream. A model’s attention layer never calls FlashAttention, or FlashInfer, or any kernel by name. It calls a registered PyTorch custom op — a function that PyTorch knows about by name but whose insides are sealed off from the compiler. Here is the op, in vllm/model_executor/layers/attention/attention.py:

def unified_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: LayerNameType,
    ...
) -> None:
    ...
    layer_name = _resolve_layer_name(layer_name)
    attn_metadata, self, kv_cache, _ = get_attention_context(layer_name)

    self.impl.forward(
        self, query, key, value, kv_cache, attn_metadata, output=output, ...)

Two things matter here. First, the op is registered with direct_register_custom_op (same file), which means torch.compile treats it as a black box — it will not trace into it. That opacity is exactly what lets the compiler split the model graph at attention boundaries, which is the mechanism Chapter 10 uses to capture everything between attentions as CUDA graphs while leaving attention itself eager. The dispatch seam and the cudagraph story are the same seam. Second, self.impl is a concrete backend object: it is the actual kernel wrapper, resolved exactly once, at startup, by get_attn_backend in vllm/v1/attention/selector.py. Think of self.impl.forward as a function pointer that was bound long before the first token ever arrived, so the per-step hot path pays no selection cost — it just calls through.

How does that binding get made? get_attn_backend builds an AttentionSelectorConfig describing the shape of the problem — the head size, the dtype, the KV-cache dtype — together with a pile of feature flags that capture which kind of attention this is: use_mla (is this a latent-attention model?), has_sink (does it use attention sinks?), use_sparse (top-k sparse attention?). It hands that config to the platform layer, which knows the GPU generation and returns a class path — a string naming the backend to load. The names live in an enum in vllm/v1/attention/backends/registry.py:

    FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
    ...
    FLASHINFER = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
    ...
    TRITON_MLA = "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
    CUTLASS_MLA = "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend"
    FLASHMLA = "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend"

That is the whole map. A name, a class path, lazily imported. The selection is cached (@cache on _cached_get_attn_backend), and the chosen backend can even demand a specific physical KV layout — get_required_kv_cache_layout() may call set_kv_cache_layout, so the kernel reaches back and dictates how Chapter 4’s blocks are arranged in memory. The kernel is not a passive consumer of the cache. It co-designs it.

The diagram below traces the two timescales of this seam. The left column is what happens once, at startup: the selector reads the config, picks a backend, and binds it into self.impl. The right column is what happens on every forward pass: the model calls the custom op, which looks up the already-bound backend and dispatches into it. Notice that the only edge crossing from startup into the hot path is the binding of self.impl — everything expensive about selection has been hoisted out of the per-step loop.

flowchart TD
    subgraph startup["once, at startup"]
        cfg["AttentionSelectorConfig (head size, dtype, kv dtype, use_mla, has_sink, use_sparse)"] --> sel["get_attn_backend in selector.py"]
        sel --> plat["platform layer picks a class path from registry.py"]
        plat --> impl["self.impl = concrete backend object"]
    end
    subgraph hot["every forward pass (eager)"]
        layer["model attention layer"] --> op["unified_attention_with_output (custom op, opaque to torch.compile)"]
        op --> ctx["get_attention_context: look up attn_metadata + kv_cache"]
        ctx --> call["self.impl.forward(...)"]
    end
    impl -. binds .-> call
    call --> backend["chosen kernel: FlashAttention / FlashInfer / one of the MLA backends"]

FlashAttention: where “paged” actually lives now

The default backend on NVIDIA is FlashAttention. Its job each step is two-phase: first write this step’s freshly-computed keys and values into the paged cache so they are visible to future tokens, then attend over the whole cache. Why two phases and not one? Because the keys and values a token produces this step must be stored before that same step’s query reads them, and every later step will reread them too. The write is what makes the cache grow by one token’s worth of state per step; the attend is what consumes the whole accumulated cache. In current vLLM these are two separate custom ops — the write was pulled out into its own do_kv_cache_update method. Since torch.compile cannot see inside either op, it has no way to know that the attend depends on the write having happened first, so vLLM threads a dummy tensor argument (kv_cache_dummy_dep) from the write into the attend purely to manufacture that data dependency and pin the ordering. The work is the same as a single fused phase; the split just makes the ordering explicit to the compiler.

The write is a scatter, not a copy. A copy would lay the new tokens down contiguously; a scatter sends each token to a different, possibly far-apart, physical slot, because paging means consecutive logical positions can live in any physical block. Here it is, in vllm/v1/attention/backends/flash_attn.py:

        reshape_and_cache_flash(
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping,
            self.kv_cache_dtype,
            layer._k_scale,
            layer._v_scale,
        )

slot_mapping is the per-token list of physical slots produced by the block table from Chapter 4 — for each token in this batch, the single integer address where its KV entry belongs in the global pool. The op reshapes the new K/V and writes each token to its slot, fusing in the FP8 KV-cache scaling we’ll meet in Chapter 12 (the _k_scale and _v_scale arguments). Note the comment in the source: it deliberately does not slice key[:num_actual_tokens], because the op infers the token count from slot_mapping’s shape — a micro-optimization that tells you how cost-sensitive this path is.
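The following pure-Python sketch shows what slot_mapping encodes and what the scatter does. It assumes a flat slot address of physical_block * block_size + offset, which is the convention this sketch uses. compute_slot_mapping and scatter_kv are hypothetical helpers, not vLLM functions. The real op works on GPU tensors and also applies the FP8 scales.

```python
BLOCK_SIZE = 16  # tokens per block, as in Chapter 4


def compute_slot_mapping(block_table, positions, block_size=BLOCK_SIZE):
    """Map each token's logical position to a flat physical slot (hypothetical helper)."""
    slots = []
    for pos in positions:
        physical_block = block_table[pos // block_size]  # which block holds it
        offset = pos % block_size                        # where inside that block
        slots.append(physical_block * block_size + offset)
    return slots


def scatter_kv(pool, new_entries, slot_mapping):
    """Write each new token's KV entry to its slot; neighbors may land far apart."""
    assert len(new_entries) == len(slot_mapping)
    for entry, slot in zip(new_entries, slot_mapping):
        pool[slot] = entry


num_blocks = 8
pool = [None] * (num_blocks * BLOCK_SIZE)  # flat global KV pool
block_table = [5, 2]                       # logical block 0 -> physical 5, 1 -> physical 2
positions = [14, 15, 16, 17]               # four new tokens straddling a block boundary
slots = compute_slot_mapping(block_table, positions)
print(slots)  # [94, 95, 32, 33]
scatter_kv(pool, ["kv14", "kv15", "kv16", "kv17"], slots)
```

Logical positions 15 and 16 are adjacent, but their slots (95 and 32) are not. That jump is the whole reason the write is a scatter rather than a copy.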

Then the attention itself:

                flash_attn_varlen_func(
                    q=query[:num_actual_tokens],
                    k=key_cache,
                    v=value_cache,
                    out=output[:num_actual_tokens],
                    cu_seqlens_q=cu_seqlens_q,
                    max_seqlen_q=max_seqlen_q,
                    seqused_k=seqused_k,
                    ...
                    block_table=block_table,
                    ...
                )

Source: vllm/v1/attention/backends/flash_attn.py

This single call is where PagedAttention lives in 2026. There is no separate “paged attention kernel” on the hot path anymore. flash_attn_varlen_func is FlashAttention’s variable-length kernel, and two of its arguments do the heavy lifting. The first is cu_seqlens_q, the cumulative sequence lengths of the queries — this is what “varlen” means. Imagine three requests in a batch with query lengths 7, 1, and 2000. Rather than pad them all to 2000 and waste the kernel’s time on padding, vLLM concatenates them end to end into one flat buffer of length 2008 and hands the kernel the running offsets [0, 7, 8, 2008]. The kernel reads those offsets to know where each sequence starts and stops, so a length-7 sequence and a length-2000 sequence ride the same launch with zero padding waste. The second is block_table, the per-request mapping from logical block to physical block from Chapter 4; passing it in is what makes the kernel walk the scattered cache instead of a contiguous tensor. Paging became a feature of the FlashAttention kernel rather than a kernel of its own. This is the practical payoff of the FlashAttention-2 work (arXiv:2307.08691): by reorganizing the attention computation to maximize work per memory pass and parallelize across sequence length and the batch, it keeps the kernel bandwidth-bound rather than launch-bound, which is precisely what a continuous-batching engine that re-plans every step needs.
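The cu_seqlens_q bookkeeping is just a prefix sum. The sketch below packs the chapter’s 7/1/2000 example and shows how the offsets recover each sequence from the flat buffer with no padding. build_cu_seqlens and split_packed are hypothetical helper names.

```python
from itertools import accumulate


def build_cu_seqlens(query_lens):
    """Running offsets [0, l0, l0+l1, ...] for a packed ragged batch."""
    return [0] + list(accumulate(query_lens))


def split_packed(flat, cu_seqlens):
    """Recover per-request slices from the flat buffer using the offsets."""
    return [flat[start:end] for start, end in zip(cu_seqlens, cu_seqlens[1:])]


query_lens = [7, 1, 2000]       # the three requests' query lengths
cu = build_cu_seqlens(query_lens)
print(cu)                       # [0, 7, 8, 2008]

flat = list(range(cu[-1]))      # 2008 packed query rows, zero padding
pieces = split_packed(flat, cu)
print([len(p) for p in pieces])  # [7, 1, 2000]

padded_rows = len(query_lens) * max(query_lens)
print(padded_rows - cu[-1])     # 3992 rows a padded layout would have wasted
```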

The varlen formulation also explains why the scheduler from Chapter 5 could collapse prefill and decode into one batch: a decode step is just a varlen sequence with query length 1, and a prefill chunk is one with query length N. The kernel does not care which is which. It reads cu_seqlens_q and goes.

The diagram below traces a single FlashAttention step end to end. Follow the new query/key/value tensors in from the top: the keys and values branch off to be scattered into the paged cache via slot_mapping, and then the attend phase reads the query plus the entire updated cache — packed by cu_seqlens_q, walked through block_table — to produce the output. The cache write must finish before the attend reads, which is the ordering the two-op split exists to guarantee.

flowchart TD
    qkv["new q, k, v for this step's tokens"] --> q["query buffer"]
    qkv --> kv["key, value buffers"]
    kv --> scatter["reshape_and_cache_flash: scatter each token to its slot"]
    slot["slot_mapping: physical slot per token"] --> scatter
    scatter --> cache["paged KV cache (updated)"]
    q --> attend["flash_attn_varlen_func"]
    cache --> attend
    cu["cu_seqlens_q: pack ragged batch, no padding"] --> attend
    bt["block_table: logical to physical blocks"] --> attend
    attend --> out["attention output"]
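To make the read side concrete, here is a slow, pure-Python reference for what the attend phase computes for one decode query (query length 1) over a paged cache. It walks block_table to gather the sequence’s keys and values from scattered physical blocks, then takes an ordinary softmax-weighted sum. paged_decode_attention is a hypothetical name. The real kernel does this tiled and fused on the GPU, and it never materializes the gathered copy.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def paged_decode_attention(q, key_cache, value_cache, block_table, seq_len, block_size):
    """Reference attention for one query over seq_len tokens stored in paged blocks."""
    keys, values = [], []
    for pos in range(seq_len):                 # seq_len bounds the read, like seqused_k
        blk = block_table[pos // block_size]   # logical block -> physical block
        off = pos % block_size                 # slot within that block
        keys.append(key_cache[blk][off])
        values.append(value_cache[blk][off])
    scale = 1.0 / math.sqrt(len(q))
    scores = [dot(q, k) * scale for k in keys]
    m = max(scores)                            # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(len(values[0]))]


block_size = 2
num_physical_blocks = 3
key_cache = [[[0.0, 0.0] for _ in range(block_size)] for _ in range(num_physical_blocks)]
value_cache = [[[0.0, 0.0] for _ in range(block_size)] for _ in range(num_physical_blocks)]

seq_keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
seq_values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
block_table = [2, 0]  # tokens 0-1 live in physical block 2, token 2 in block 0
for pos, (k, v) in enumerate(zip(seq_keys, seq_values)):
    blk, off = block_table[pos // block_size], pos % block_size
    key_cache[blk][off] = k
    value_cache[blk][off] = v

# Physical block 1 and the second slot of block 0 hold no data for this
# sequence; they are never read because seq_len stops the walk.
out = paged_decode_attention([1.0, 0.5], key_cache, value_cache, block_table,
                             seq_len=3, block_size=block_size)
print(out)
```

Running the same function on a contiguous layout (block size 1, identity block table) gives the same result, which is the point: paging changes where the bytes live, not what attention computes.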

One warning in this file is worth internalizing, because it reframes how you read every backend forward in vLLM:

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.

Source: vllm/v1/attention/backends/flash_attn.py

Because attention runs eager in the gaps between captured graphs (Chapter 10), the Python in this method is on the critical path of every decode step. A stray .view() that touches no GPU memory can still cost you tail latency at high QPS. This is the kind of thing a distributed-systems engineer recognizes instantly: it is the per-request CPU overhead problem, just relocated into a kernel wrapper.

The original PagedAttention CUDA kernels still exist — csrc/libtorch_stable/attention/paged_attention_v1.cu and paged_attention_v2.cu, both Apache-licensed and carrying their FasterTransformer lineage in the header — but they are off the main path now, kept as a reference implementation and a fallback. Reading paged_attention_v2.cu is still worthwhile for your performance instincts: v2 splits each sequence’s KV across thread blocks and does a second reduction pass, the classic split-K trick for keeping the GPU busy when a single long sequence would otherwise starve it. That is the same problem FlashAttention-2 solves more generally.

FlashInfer and cascade attention

FlashInfer is the alternative NVIDIA backend, and the reason to reach for it is in vllm/v1/attention/backends/flashinfer.py. It carries a special metadata field:

    use_cascade: bool
    """
    If True, the entire batch is a cascade attention call, and the
    `prefill` and `decode` fields will both be None.
    """

Cascade attention is the kernel-level partner of the prefix caching from Chapter 7. When many requests in a batch share a long common prefix — a system prompt, a few-shot block — naive attention reads that shared KV once per request. If sixteen requests all share a 1000-token system prompt, that prompt’s keys and values get streamed out of memory sixteen times per step, and since decode is bandwidth-bound (Chapter 3) that rereading is the cost.

The curve below makes the divergence concrete for that 1000-token shared prefix: naive attention rereads the prefix once per request, so its per-step prefix traffic grows linearly with batch size, while cascade reads the prefix exactly once no matter how wide the batch fans out. At a batch of sixteen that is the difference between 16000 and 1000 prefix-token reads per step.

Illustrative: counts follow exactly from the chapter’s 1000-token-prefix example (naive = prefix x batch, cascade = prefix); a real batch’s prefix length and fan-out vary.
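The per-step counts behind that curve take one multiplication to compute. prefix_reads is a hypothetical helper. It counts only the shared prefix and ignores each request’s unique suffix, which both schemes read once.

```python
def prefix_reads(prefix_len, batch_size, cascade):
    """Shared-prefix tokens streamed from memory per decode step."""
    return prefix_len if cascade else prefix_len * batch_size


for batch in (1, 4, 16, 64):
    naive = prefix_reads(1000, batch, cascade=False)
    casc = prefix_reads(1000, batch, cascade=True)
    print(batch, naive, casc)
```

For a batch of 16 it prints 16 16000 1000, matching the figures in the text.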

Cascade fixes this by exploiting a property of softmax attention: an attention result over a long context can be computed in pieces and stitched back together exactly, as long as you carry along each piece’s softmax normalization statistics (its running max and sum of exponentials) so the merge reweights the pieces correctly. This is the same online-softmax trick that makes FlashAttention itself tile-able. Cascade uses it to cut each request’s context into a shared prefix and a private suffix, so that the prefix piece can be shared across the whole batch. It splits the computation into two levels. Level one attends every query in the batch over the shared prefix, reading that prefix’s KV exactly once for the whole batch. Level two attends each request’s queries over only that request’s unique suffix. A numerically correct softmax merge then combines the two partial results. The shared prefix went from sixteen rereads to one.
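The merge is the part that has to be exactly right, so here is a small pure-Python sketch of it for a single query vector. attend_partial returns a piece’s unnormalized output plus its max score and sum of exponentials. merge_partials rescales both pieces to a common max before combining them. These names are hypothetical; FlashInfer’s kernels perform the same algebra in fused, tiled form. The final check confirms that merging prefix and suffix reproduces attention over the full context.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def attend_partial(q, keys, values):
    """Attention over one piece of context: (max score, sum of exps, weighted sum).

    The 1/sqrt(d) scale is omitted; it does not affect the merge.
    """
    scores = [dot(q, k) for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    acc = [sum(e * v[d] for e, v in zip(exps, values)) for d in range(len(values[0]))]
    return m, sum(exps), acc


def merge_partials(a, b):
    """Softmax-correct merge of two partial results over disjoint key sets."""
    (ma, sa, acc_a), (mb, sb, acc_b) = a, b
    m = max(ma, mb)
    wa, wb = math.exp(ma - m), math.exp(mb - m)  # rescale each piece to the shared max
    acc = [wa * x + wb * y for x, y in zip(acc_a, acc_b)]
    return m, wa * sa + wb * sb, acc


def finalize(partial):
    _, s, acc = partial
    return [x / s for x in acc]


q = [0.3, -1.2]
prefix_k = [[1.0, 0.0], [0.5, 0.5], [-1.0, 2.0]]  # shared by every request
prefix_v = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
suffix_k = [[0.0, -1.0], [3.0, 1.0]]              # unique to this request
suffix_v = [[4.0, 1.0], [-1.0, 0.5]]

level1 = attend_partial(q, prefix_k, prefix_v)    # prefix KV read once for the batch
level2 = attend_partial(q, suffix_k, suffix_v)    # suffix KV read per request
cascaded = finalize(merge_partials(level1, level2))
full = finalize(attend_partial(q, prefix_k + suffix_k, prefix_v + suffix_v))
print(max(abs(x - y) for x, y in zip(cascaded, full)))  # near 0: float rounding only
```

Because the merge is associative, the same function can fold in any number of pieces. That is also how split-K schemes such as paged_attention_v2’s second reduction pass recombine per-partition results.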

At dispatch the kernel is invoked through a planned wrapper:

        if attn_metadata.use_cascade:
            # Cascade attention (rare case).
            assert attn_metadata.cascade_wrapper is not None
            output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache))
            return output

Source: vllm/v1/attention/backends/flashinfer.py

The diagram below contrasts the naive path with the cascade path for a batch that shares a prefix. The point to read off it is the fan-in on the shared-prefix box: in the naive case every request’s attend reaches into it (three arrows, three rereads), while in the cascade case a single level-one attend reads it once and feeds the merge.

flowchart LR
    subgraph naive["naive: shared prefix reread per request"]
        r1["request 1 query"] --> p1["attend over prefix + suffix 1"]
        r2["request 2 query"] --> p2["attend over prefix + suffix 2"]
        r3["request 3 query"] --> p3["attend over prefix + suffix 3"]
        p1 --> pref["shared prefix KV in memory"]
        p2 --> pref
        p3 --> pref
    end
    subgraph cascade["cascade: shared prefix read once"]
        allq["all queries in batch"] --> shared["level 1: attend over shared prefix once"]
        q1["request 1 query"] --> s1["level 2: attend over suffix 1"]
        q2["request 2 query"] --> s2["level 2: attend over suffix 2"]
        q3["request 3 query"] --> s3["level 2: attend over suffix 3"]
        shared --> merge["softmax-correct merge of partial results"]
        s1 --> merge
        s2 --> merge
        s3 --> merge
    end

The “rare case” comment is honest: cascade only pays off when the shared prefix is long and the batch fan-out is wide, and the metadata builder gates it on a measured common_prefix_len. It is honest in another way too — at the time of writing the FlashInfer backend’s use_cascade_attention actually returns False with a “doesn’t work, disable it for now” note, so the path is dormant in practice. But the mechanism is the point, and when it does fire it turns prefix sharing from a memory win (Chapter 7 saved you the recompute) into a bandwidth win (cascade also saves you the reread). This is the broader thesis of the FlashInfer paper (arXiv:2501.01005): attention serving has many shapes — paged decode, ragged prefill, shared-prefix cascade — and a serving engine wants a single customizable engine that compiles the right kernel for each, rather than a pile of hand-written specializations. Read it for the block-sparse formulation that unifies paging and cascading under one abstraction.
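Here is one plausible way to measure a shared prefix, sketched under stated assumptions. With prefix caching, requests that share a prefix point at the same physical blocks. Counting the leading block-table entries common to every request therefore gives a common prefix length in tokens, at whole-block granularity. common_prefix_len and should_use_cascade are hypothetical helpers. The min_prefix_tokens and min_batch thresholds are made-up illustrative numbers, not vLLM’s policy.

```python
def common_prefix_len(block_tables, block_size=16):
    """Tokens covered by leading physical blocks shared by every request."""
    shared = 0
    for column in zip(*block_tables):  # zip stops at the shortest table
        if all(b == column[0] for b in column):
            shared += 1
        else:
            break
    return shared * block_size


def should_use_cascade(block_tables, min_prefix_tokens=256, min_batch=8, block_size=16):
    """Hypothetical gate: only worth it for a long prefix and a wide fan-out."""
    prefix = common_prefix_len(block_tables, block_size)
    return prefix >= min_prefix_tokens and len(block_tables) >= min_batch


# A 1000-token system prompt fills 62 whole 16-token blocks; the partially
# filled 63rd block also holds request-specific tokens, so it is not shared.
tables = [[*range(62), 100 + i] for i in range(16)]
print(common_prefix_len(tables))       # 992
print(should_use_cascade(tables))      # True
print(should_use_cascade(tables[:2]))  # False: fan-out too narrow
```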

MLA: the KV cache becomes a latent vector

Everything so far assumed the KV cache you sized in Chapter 4: two tensors, keys and values, one entry per attention head per token. For a model with many KV heads, that cache is large, and since decode is bandwidth-bound (Chapter 3), every byte of it is reread on every step. The bytes are the bottleneck, so the obvious lever is to store fewer of them.

Multi-head Latent Attention, introduced with DeepSeek-V2 (arXiv:2405.04434), does exactly that, and it attacks the cache itself rather than the kernel around it. The key idea is that the per-head keys and values, across all heads, are highly redundant — they can be reconstructed from a much smaller shared representation. So instead of caching per-head keys and values, MLA caches a single low-rank latent vector per token. “Low-rank” here means the latent lives in a much smaller dimension than the full per-head K and V it stands in for. When attention needs the actual per-head K and V, the kernel reconstructs them on the fly by multiplying the latent through fixed up-projection matrices (matrices learned at training time that expand the compact latent back out to full head dimensions). The trick that makes this cheap at serve time is that those up-projection matrices can be algebraically folded into the query and output projections, so the kernel often never has to materialize the full K and V at all.
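The folding claim on the key side is plain associativity, and a few lines of pure Python can check it. Assume a cached latent c per token, a per-head up-projection W_uk that expands c into that head’s key k = W_uk c, and a query q for the same head. Then q · (W_uk c) equals (W_uk^T q) · c. The absorbed query is computed once and reused against every cached latent, so no full key is ever built. The value side folds into the output projection in the same way. The sizes and random matrices are toy values, not DeepSeek’s. The decoupled rotary slice, which does not fold this way, is left out.

```python
import random


def matvec(m, v):
    """Matrix (list of rows) times vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def transpose(m):
    return [list(col) for col in zip(*m)]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


random.seed(0)
latent_dim, head_dim, num_tokens = 8, 4, 5  # toy sizes, not DeepSeek's
# Hypothetical learned up-projection for one head: latent -> that head's key.
W_uk = [[random.gauss(0, 1) for _ in range(latent_dim)] for _ in range(head_dim)]
cached_latents = [[random.gauss(0, 1) for _ in range(latent_dim)] for _ in range(num_tokens)]
q = [random.gauss(0, 1) for _ in range(head_dim)]  # this head's query

# Compute-friendly path: reconstruct every per-head key, then score.
keys = [matvec(W_uk, c) for c in cached_latents]
scores_materialized = [dot(q, k) for k in keys]

# Data-movement-friendly path: fold W_uk into the query once, then score
# every cached latent directly; no per-head key is ever built.
q_absorbed = matvec(transpose(W_uk), q)  # length latent_dim
scores_absorbed = [dot(q_absorbed, c) for c in cached_latents]

print(max(abs(a - b) for a, b in zip(scores_materialized, scores_absorbed)) < 1e-9)  # True
```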

The cleanest evidence is the cache shape. Recall the FlashAttention layout was [num_blocks, 2, block_size, num_kv_heads, head_size]. The MLA layout, from vllm/model_executor/layers/attention/mla_attention.py:

    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,  # assumed to be 1 for MLA
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        return (num_blocks, block_size, head_size)

There is no 2 (no separate K and V — the latent carries both) and no head dimension (num_kv_heads is 1). And the head_size is not a head size in the usual sense:

        self.head_size = kv_lora_rank + qk_rope_head_dim

Source: vllm/model_executor/layers/attention/mla_attention.py

So a cached MLA token is the latent itself (kv_lora_rank, the compressed part) plus a small slice that carries rotary position information (qk_rope_head_dim), packed into one vector. For DeepSeek-V3 those are 512 and 64, so each cached token is a single $512 + 64 = 576$-element vector. Contrast that with a conventional model caching, say, 8 KV heads of 128 dims each across two tensors — $2 \times 8 \times 128 = 2048$ elements per token. Against this modest 8-head baseline that is already a 3.6x shrink per token, and against a wide-headed model it reaches the order of magnitude that makes MLA worth the trouble. This is exactly the forward-pointer Chapter 4 flagged: not every model has a fat growing KV cache. MLA’s is compressed, and because decode rereads the whole cache every step, a smaller cache is a proportionally cheaper reread — that directly buys back the concurrency and context length that decode bandwidth was spending.

The curve below traces per-sequence cache size (one attention layer, fp16) against context length for both layouts. Because cache size is linear in context length, MLA’s smaller per-token footprint shows up as a flatter line: at 128k (128,000) tokens this 8-head baseline holds roughly 524 MB per layer where MLA holds about 147 MB. The absolute gap grows with context length, and the ratio grows with the conventional model’s KV-head count.

Illustrative: bytes computed from the chapter’s element counts (2048 vs 576 elements/token x 2 bytes fp16) for a single attention layer; the conventional curve scales with that model’s KV-head count.
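The per-token and per-layer numbers above follow from a few multiplications. The sketch below recomputes them under these assumptions: fp16 (2 bytes per element), a single attention layer, and 128,000 tokens for “128k”. The helper names are hypothetical. The last line uses an equally hypothetical 128-KV-head, 128-dim conventional model to show where the order-of-magnitude shrink comes from.

```python
BYTES_FP16 = 2


def conventional_elems_per_token(num_kv_heads, head_dim):
    return 2 * num_kv_heads * head_dim      # separate K and V tensors


def mla_elems_per_token(kv_lora_rank, qk_rope_head_dim):
    return kv_lora_rank + qk_rope_head_dim  # one packed latent vector


conv = conventional_elems_per_token(num_kv_heads=8, head_dim=128)
mla = mla_elems_per_token(kv_lora_rank=512, qk_rope_head_dim=64)
print(conv, mla, round(conv / mla, 2))      # 2048 576 3.56

context = 128_000
for name, elems in (("conventional", conv), ("mla", mla)):
    megabytes = elems * BYTES_FP16 * context / 1e6
    print(name, round(megabytes, 1), "MB per layer")
# conventional 524.3 MB per layer
# mla 147.5 MB per layer

wide = conventional_elems_per_token(num_kv_heads=128, head_dim=128)
print(round(wide / mla, 1))                 # 56.9
```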

The cost is moved into compute, and the design note at the top of mla_attention.py is unusually candid about the tradeoff. MLA has two computation paths:

MLA has two possible ways of computing, a data-movement friendly approach and a
compute friendly approach. We generally want to use the compute friendly
approach for "prefill" ... and the data-movement friendly approach for "decode"

Why split into two paths at all? Because prefill and decode sit on opposite sides of the roofline from Chapter 3, and the latent gives you a knob that lets you choose where the work lands. In prefill you have many query tokens against a comparable number of keys, so there is plenty of arithmetic to amortize a memory pass — you are compute-bound. There it pays to up-project the latents back to full per-head K and V once and run ordinary multi-head attention; this is the compute-friendly path, and “data-movement friendly” would be wasted because data movement is not the bottleneck. In decode you have one query token attending over a long context, so there is very little arithmetic and the cost is dominated by reading the cache — you are bandwidth-bound. There it pays to not up-project: keep everything in the compressed latent space and run what amounts to multi-query attention with one fat head. That spends more FLOPs per token (the up-projection math gets folded into the per-token work) but touches far fewer bytes, which is exactly the right trade when bytes are the bottleneck. The same model thus uses two different kernel strategies depending on which phase the scheduler put a token in.
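A toy router makes the fork concrete. It is a hypothetical sketch, not vLLM’s scheduler or backend code. It assigns each request in a mixed batch to the data-movement-friendly (stay latent) path or the compute-friendly (up-project) path by query length. decode_threshold is a hypothetical knob that defaults to 1, meaning only true decode steps stay latent.

```python
def route_mla_batch(query_lens, decode_threshold=1):
    """Split request indices into (decode_path, prefill_path) by query length."""
    decode_path, prefill_path = [], []
    for i, q_len in enumerate(query_lens):
        if q_len <= decode_threshold:
            decode_path.append(i)   # few queries: bandwidth-bound, keep latent
        else:
            prefill_path.append(i)  # many queries: compute-bound, up-project
    return decode_path, prefill_path


batch = [1, 1, 37, 1, 2000]
print(route_mla_batch(batch))                        # ([0, 1, 3], [2, 4])
print(route_mla_batch(batch, decode_threshold=128))  # ([0, 1, 2, 3], [4])
```

Raising the threshold sends short prefill chunks down the decode path as well. The FlashMLA backend’s reorder_batch_threshold of 128 is a knob in that spirit, although the exact comparison it uses is not modeled here.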

The diagram below shows both paths diverging from the one shared latent cache. Read it as a fork: the cache is identical for both phases, but prefill walks left into the up-project-then-attend path while decode walks right into the stay-compressed path.

flowchart TD
    cache["latent KV cache: one 576-dim vector per token (kv_lora_rank + qk_rope_head_dim)"] --> fork{"which phase?"}
    fork -->|"prefill: many queries, compute-bound"| up["up-project latent to full per-head K, V"]
    up --> mha["ordinary multi-head attention (compute-friendly)"]
    fork -->|"decode: one query, bandwidth-bound"| keep["keep latent compressed, fold up-projection into the math"]
    keep --> mqa["multi-query attention with one fat head (data-movement-friendly)"]
    mha --> out["attention output"]
    mqa --> out

That is why the mla/ directory is a family of backends — flashmla.py, cutlass_mla.py, flashinfer_mla.py, triton_mla.py — each pairing a prefill path with a decode path tuned for a specific GPU. The selector’s use_mla flag is what routes a DeepSeek model down this entire parallel universe.

There is one more turn of the screw: the 2026 flagship models push toward sparse MLA. Even after each token is compressed, decode still rereads every token’s latent on every step, and most of those tokens contribute almost nothing to the result. So instead of attending over the whole context, a lightweight indexer scores the cached tokens against the current query and selects only the top-k most relevant ones (say, the 2048 that matter out of a 128k context), and attention runs over just those. vLLM has backends for this, flashmla_sparse.py and flashinfer_mla_sparse.py. They are fed by indexer.py, which produces the per-request top-k choices, and by the Triton helpers in sparse_utils.py, which translate those per-request top-k indices into global cache positions the kernel can fetch. The FlashMLA backend even reorders the batch around the prefill/decode split. Its source declares reorder_batch_threshold: int = 128 with the comment “process small prefills with decode pathway”, which captures how blurry the prefill/decode line has become once the kernel, not the scheduler, owns the distinction. Sparse attention over a compressed cache is a credible answer to long-context decode cost, but it is still moving: neither the right kernel nor the right top-k policy is settled.
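The selection step can be sketched in a few lines, assuming the indexer has already produced one relevance score per cached token; the scoring model itself is not shown. topk_slots picks the k highest-scoring logical positions with heapq.nlargest. It then translates them into flat cache slots through the block table, which is the job the chapter assigns to sparse_utils.py. The function and its slot formula (physical_block * block_size + offset) are illustrative assumptions, not that file’s API.

```python
import heapq


def topk_slots(index_scores, k, block_table, block_size):
    """Pick the k most relevant logical positions, then map them to physical slots."""
    top_positions = heapq.nlargest(k, range(len(index_scores)), key=index_scores.__getitem__)
    # Order does not change the attention result (up to float rounding);
    # sorting only makes the later fetches more sequential.
    top_positions.sort()
    slots = [block_table[p // block_size] * block_size + p % block_size
             for p in top_positions]
    return top_positions, slots


scores = [0.1, 2.5, -0.3, 1.7, 0.0, 3.1, 0.2, 1.9]  # one score per cached token
positions, slots = topk_slots(scores, k=3, block_table=[7, 3], block_size=4)
print(positions)  # [1, 5, 7]
print(slots)      # [29, 13, 15]
```

Attention then runs over only those slots. At the text’s example of 2048 tokens out of 128k, that means reading about 1.6% of the cached latents per step.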

What is solved, what is not

Step back and the shape is clear. Paged layout (Chapter 4) and the attention kernel are not independent. The layout is a contract, and the kernel both honors it and, increasingly, dictates it — varlen indexing for ragged batches, block tables for paging, cascade for shared prefixes, and for MLA a wholesale redefinition of what a cache entry even is. The custom-op seam is what lets vLLM swap among these without the model code knowing, and what lets the compiler in Chapter 10 capture everything around them.

What remains genuinely unsolved is selection and portability. The right backend depends on GPU generation, precision, head dimensions, and model family, and that support matrix is sparse and shifting: a kernel that is fastest on Hopper may not exist on Blackwell yet, and MLA’s two-path design means a regression can hide in just the decode kernel of just one backend. The eager-mode CPU-overhead warning is a standing tax that every new backend pays. And sparse MLA is early: top-k attention trades exactness for bandwidth, and how much that costs in quality at long context is still being measured in public.

With the kernel understood as an opaque, dispatched, eager-in-the-gaps op, Chapter 10 can take the next step: capturing everything between these ops as CUDA graphs to kill the per-step launch overhead that, on a one-token decode, can otherwise outrun the GPU itself.

Further reading

  • FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning — arXiv:2307.08691 — reorganizes attention to maximize work per memory pass and parallelize across sequence length, the basis of the varlen kernel vLLM calls by default.
  • FlashInfer: Efficient and Customizable Attention Engine for LLM Serving — arXiv:2501.01005 — a block-sparse abstraction that unifies paged decode, ragged prefill, and shared-prefix cascade into one compiled engine.
  • DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model — arXiv:2405.04434 — introduces Multi-head Latent Attention, which compresses the KV cache into a single low-rank latent vector per token.