KV cache and PagedAttention: memory as the real bottleneck

Chapter 3 left us with an uncomfortable asymmetry. Decode is memory-bandwidth-bound: every step reads the whole model and the whole KV cache to emit a single token, and you cannot make that token come out faster by adding FLOPs. The obvious escape hatch is concurrency. If one sequence’s decode step leaves the matrix units mostly idle, run more sequences in the same step and amortize that weight read across all of them. Batching is free throughput, right up until you run out of the one resource decode actually consumes.

That resource is not compute. It is memory, and specifically the place we store the keys and values for every token already in flight. This chapter is about why the KV cache, not the GEMM, is what caps how many requests a replica can hold at once, and about the single data-structure decision that turns that cap from “embarrassingly low” into “the highest the hardware allows.” Everything in Part II builds on the allocator we develop here.

Why the KV cache is the budget

Recall the shape of attention during decode. For each new token, the model computes a query and attends over the keys and values of every preceding token in the sequence. Those past K and V tensors do not change, so we cache them rather than recomputing them. The cache grows by one token’s worth of K and V per layer, per decode step, for the entire life of the sequence.

How much is that? vLLM’s own accounting lives in its KV-cache spec. A full-attention layer reports its worst-case footprint as a function of the maximum model length:

# vllm/v1/kv_cache_interface.py  (FullAttentionSpec.max_memory_usage_bytes)
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
    max_model_len = vllm_config.model_config.max_model_len
    ...
    return cdiv(max_model_len, self.block_size) * self.page_size_bytes

Source: vllm/v1/kv_cache_interface.py

and page_size_bytes for a block of block_size tokens is, stripped of the quantization branches:

# vllm/v1/kv_cache_interface.py  (AttentionSpec.real_page_size_bytes)
return (
    2  # one for K, one for V
    * self.block_size
    * self.num_kv_heads
    * self.head_size
    * get_dtype_size(self.dtype)
)

Source: vllm/v1/kv_cache_interface.py

Walk that formula left to right and it is just counting the numbers that have to be stored. The factor of two is K and V, since attention keeps both. block_size is how many tokens share a block. num_kv_heads is the number of distinct key/value heads the layer keeps; with grouped-query attention several query heads share one KV head, which is precisely the architectural trick that shrinks this term. head_size is the width of each head’s vector, and get_dtype_size(self.dtype) is the bytes per number (2 for bf16). Multiply all of that and you have the bytes for one block of one layer; multiply through by the number of layers $L$ and divide by block_size to get the per-token cost:

$$\text{bytes/token} = 2 \cdot L \cdot n_{\text{kv}} \cdot d_{\text{head}} \cdot b_{\text{dtype}}$$

where $n_{\text{kv}}$ is the number of KV heads, $d_{\text{head}}$ the head width, and $b_{\text{dtype}}$ the bytes per number. For a typical 7-8B model in bf16 with grouped-query attention, that lands somewhere around 100-150 KB per token. A single 8k-token context is a gigabyte or more of KV. On an 80 GB card holding ~16 GB of weights, you have maybe 60 GB left for cache after the runtime’s overhead, and that 60 GB is what you are dividing among all concurrent sequences. The KV cache, not the model, decides your concurrency.
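Here is that arithmetic as a minimal sketch. It assumes a Llama-3-8B-style shape (32 layers, 8 KV heads, head width 128, bf16) and a 60 GiB KV budget. Both the shape and the budget are assumptions, and the helper functions are hypothetical, not vLLM APIs. The sketch also checks that the per-block formula and the per-token formula agree.

```python
# Assumed model shape (Llama-3-8B-like); not read from any real config.
NUM_LAYERS = 32
NUM_KV_HEADS = 8      # grouped-query attention
HEAD_SIZE = 128
DTYPE_BYTES = 2       # bf16
BLOCK_SIZE = 16


def page_size_bytes(block_size: int) -> int:
    """Bytes for one block of one layer, same terms as real_page_size_bytes."""
    return 2 * block_size * NUM_KV_HEADS * HEAD_SIZE * DTYPE_BYTES


def kv_bytes_per_token() -> int:
    """2 * L * n_kv * d_head * b_dtype: the per-token cost summed over layers."""
    return 2 * NUM_LAYERS * NUM_KV_HEADS * HEAD_SIZE * DTYPE_BYTES


per_token = kv_bytes_per_token()
# Same number reached the other way: one page per layer, divided by block_size.
assert page_size_bytes(BLOCK_SIZE) * NUM_LAYERS // BLOCK_SIZE == per_token

context_len = 8192
per_context = per_token * context_len
kv_budget = 60 * 1024**3   # assumed ~60 GiB left after weights and runtime overhead

print(f"{per_token / 1024:.0f} KiB per token")                     # 128 KiB
print(f"{per_context / 1024**3:.2f} GiB per 8k context")           # 1.00 GiB
print(f"{kv_budget // per_context} full-length 8k sequences fit")  # 60
```

Sixty is the ceiling only if every sequence actually reaches 8k tokens. Shorter requests could fit many more, but only if the allocator does not reserve the full length for each of them.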

So the real question of single-replica serving is: how do you pack the most live sequences into a fixed KV budget? And the answer turns out to depend almost entirely on how you allocate that budget.

The naive layout and the waste it bakes in

The pre-PagedAttention way to store a sequence’s KV cache is the way you would store anything else: one contiguous buffer per sequence. You do not know how long a request will run, so you reserve for the worst case, max_model_len tokens, up front.

This is a disaster for two reasons that any systems engineer will recognize as the classic fragmentation pair. First, internal fragmentation: a request that reserves 8,000 slots but generates 200 tokens wastes the other 7,800 for its entire lifetime. Output length is unknown at admission and wildly variable in real traffic, so this is not an edge case but the common case. Second, external fragmentation: contiguous buffers of varying sizes leave unusable gaps between them. A general-purpose heap allocator without fixed size classes faces exactly this problem. The result, measured in the originating paper, was that real systems used only a small fraction of their KV memory for actual tokens. The rest was reserved-but-empty or stranded between buffers. You were paying for 60 GB and serving as if you had 15.

The curve below makes the internal-fragmentation cost concrete. It plots the KV memory a single sequence consumes against the number of tokens it actually generates, holding max_model_len at 8k. Contiguous allocation reserves the worst case up front, so it sits flat at ~1 GB no matter how short the request turns out to be; paging tracks actual usage almost exactly, with waste capped at one partial block. The vertical gap between the two lines, at any length short of the maximum, is pure reserved-but-empty memory.

Illustrative, not measured: assumes ~125 KB/token (a typical 7-8B bf16 GQA model) and a 16-token block. The gap between the lines is the reserved-but-empty waste that paging removes.
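The two lines in that figure can be regenerated with a few lines of arithmetic. This sketch uses the caption's own assumptions: ~125 KB per token, 16-token blocks, and an 8k max_model_len. The function names are hypothetical.

```python
BYTES_PER_TOKEN = 125_000   # caption's assumption: ~125 KB/token
BLOCK_SIZE = 16             # caption's assumption: 16-token blocks
MAX_MODEL_LEN = 8192


def contiguous_bytes(num_tokens: int) -> int:
    """Contiguous layout: the worst case is reserved up front, whatever the length."""
    return MAX_MODEL_LEN * BYTES_PER_TOKEN


def paged_bytes(num_tokens: int) -> int:
    """Paged layout: only the blocks actually touched (ceil division)."""
    num_blocks = -(-num_tokens // BLOCK_SIZE)
    return num_blocks * BLOCK_SIZE * BYTES_PER_TOKEN


for n in (200, 1_000, 4_000, 8_192):
    used = n * BYTES_PER_TOKEN
    print(f"{n:>5} tokens | contiguous waste {(contiguous_bytes(n) - used) / 1e6:7.1f} MB"
          f" | paged waste {(paged_bytes(n) - used) / 1e6:5.1f} MB")
```

Paged waste never exceeds one block: here at most 15 tokens, under 2 MB. Contiguous waste, by contrast, is nearly the full reservation for short requests.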

The fix is the same fix operating systems applied to exactly this problem decades ago. Stop demanding contiguity. Page it.

Paging the cache: fixed blocks and a pool

PagedAttention (the technique introduced in Efficient Memory Management for Large Language Model Serving with PagedAttention, arXiv:2309.06180) carves KV memory into fixed-size blocks, each holding a small fixed number of tokens, and lets a sequence’s blocks live anywhere in the pool. A block is just a fixed-size chunk of GPU memory big enough to hold the K and V tensors for a handful of consecutive tokens. The pool is the full set of these blocks, all carved from the leftover KV budget at startup. The trick is that the blocks for one sequence do not have to sit next to each other in the pool. The sequence’s logical token order (token 0, then token 1, then token 2, and so on) is reconstructed by a per-sequence block table. This is a small array that says “my first block of tokens lives at physical block 7, my second at physical block 2, my third at physical block 9.” It is the direct analogue of an OS page table mapping a process’s contiguous virtual addresses onto scattered physical frames. The application thinks it has one contiguous run; physically, the run lives in whatever frames happen to be free.

The diagram below traces that indirection for a single sequence. The logical view on the left is the contiguous token stream the model reasons about; the block table in the middle is the lookup the engine actually keeps; the pool on the right is physical GPU memory, where this sequence’s blocks are interleaved with other sequences’ blocks and free space.

flowchart LR
    subgraph Logical["Logical view (one sequence)"]
        L0["tokens 0-15"]
        L1["tokens 16-31"]
        L2["tokens 32-47"]
    end
    subgraph Table["Block table"]
        T0["slot 0 to phys 7"]
        T1["slot 1 to phys 2"]
        T2["slot 2 to phys 9"]
    end
    subgraph Pool["Physical block pool (GPU)"]
        P2["phys 2"]
        P7["phys 7"]
        P9["phys 9"]
        PX["other seqs / free"]
    end
    L0 --> T0 --> P7
    L1 --> T1 --> P2
    L2 --> T2 --> P9

Because every block is the same size, the engine never has to find a hole of a particular shape; any free block will do for any sequence. That single property is what dissolves external fragmentation.
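The sketch below is a toy model of that indirection. It assumes a tiny pool and a simple list-based free list; a real engine keeps this bookkeeping quite differently, and BlockPool and Sequence are hypothetical names. Two sequences grow in lockstep, so their blocks interleave in the pool as in the diagram. Each block table still recovers its own sequence's logical order.

```python
BLOCK_SIZE = 16


class BlockPool:
    """Toy pool: physical block ids handed out from a free list."""

    def __init__(self, num_blocks: int) -> None:
        self.free = list(range(num_blocks))

    def allocate(self) -> int:
        if not self.free:
            raise MemoryError("KV pool exhausted")
        return self.free.pop()          # any free block will do

    def release(self, block_ids: list[int]) -> None:
        self.free.extend(block_ids)


class Sequence:
    def __init__(self, pool: BlockPool) -> None:
        self.pool = pool
        self.block_table: list[int] = []   # logical block index -> physical block id
        self.num_tokens = 0

    def append_token(self) -> None:
        # Grab a new physical block only when the previous one is full.
        if self.num_tokens % BLOCK_SIZE == 0:
            self.block_table.append(self.pool.allocate())
        self.num_tokens += 1

    def physical_location(self, token_idx: int) -> tuple[int, int]:
        # Logical token -> (physical block, offset inside that block).
        return self.block_table[token_idx // BLOCK_SIZE], token_idx % BLOCK_SIZE


pool = BlockPool(num_blocks=10)
a, b = Sequence(pool), Sequence(pool)
for _ in range(40):          # interleaved growth scatters both sequences' blocks
    a.append_token()
    b.append_token()
print(a.block_table, b.block_table)
```

Neither sequence ever asks for a hole of a particular size. Each one takes one block at a time from the same list.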

vLLM’s pool is a flat list of block objects, one per physical block, all initially free:

# vllm/v1/core/block_pool.py
self.blocks: list[KVCacheBlock] = [
    KVCacheBlock(idx) for idx in range(num_gpu_blocks)
]
# Free block queue that constructs and manipulates a doubly linked
# list of free blocks (including eviction candidates when caching is
# enabled).
self.free_block_queue = FreeKVCacheBlockQueue(self.blocks)

Source: vllm/v1/core/block_pool.py

The default block size is 16 tokens (DEFAULT_BLOCK_SIZE: ClassVar[int] = 16 in vllm/config/cache.py). That number is a tuning knob with a real tradeoff. Smaller blocks shrink internal fragmentation, because the only waste is in the last, partially filled block of each sequence. That waste is at most block_size - 1 tokens (15 at the default), regardless of context length. Larger blocks cut bookkeeping and let the attention kernel process longer contiguous runs. Sixteen is the compromise that has held up.
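The tradeoff is easy to quantify for one sequence. The sketch below uses a hypothetical helper and a 4,097-token sequence. That length is one token past a block boundary for every power-of-two block size, so each row shows the worst case. For each block size it reports the wasted slots in the last block and the number of block-table entries the engine must track. It models only bookkeeping, not kernel efficiency.

```python
def paging_overhead(num_tokens: int, block_size: int) -> tuple[int, int]:
    """Return (unused token slots in the last block, block-table entries)."""
    num_blocks = -(-num_tokens // block_size)       # ceil division
    wasted = num_blocks * block_size - num_tokens   # always < block_size
    return wasted, num_blocks


for bs in (8, 16, 32, 128):
    wasted, entries = paging_overhead(num_tokens=4097, block_size=bs)
    print(f"block_size={bs:>3}: {wasted:>3} wasted slots, {entries:>3} table entries")
```

Going from 16 to 128 cuts the table roughly eightfold but raises the waste ceiling from 15 slots to 127.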

Notice what fixed-size blocks buy you. Every free block is interchangeable, so there are no size classes and no external fragmentation, ever. A request grows by grabbing blocks from the free list one at a time as it produces tokens, and returns them to the free list when it finishes. The free list itself is a hand-rolled doubly-linked list rather than a deque precisely so a block can be plucked out of the middle in $O(1)$:

# vllm/v1/core/kv_cache_utils.py  (FreeKVCacheBlockQueue docstring)
# This class organizes a list of KVCacheBlock objects to a doubly linked
# list of free blocks. We implement this class instead of using Python
# builtin deque to support removing a block in the middle of the queue
# in O(1) time.

Source: vllm/v1/core/kv_cache_utils.py

That $O(1)$ removal from the middle matters because the free list does double duty as an eviction queue for prefix caching, the subject of Chapter 7. When a sequence finishes, its blocks go back on the free list. The K/V they hold is left intact, though, in case a later request shares the same prefix and can reuse it. Such a block is “free” in the sense that the allocator may hand it out, yet “occupied” in the sense that it still carries useful data. Reusing it means reaching into the middle of the queue to pull out that specific block, hence the linked list. So hold this thought: a “free” block is not necessarily empty.
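A minimal sketch of such a queue follows. It assumes sentinel head and tail nodes and a block_hash field that marks blocks still holding reusable prefix KV. It illustrates the idea behind FreeKVCacheBlockQueue, not its actual code; Block and FreeBlockQueue are hypothetical names. Ordinary allocation pops from the head, and a prefix-cache hit unlinks one specific block from the middle. Both operations are $O(1)$.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Optional


@dataclass(eq=False)
class Block:
    idx: int
    block_hash: Optional[int] = None  # set while the block holds reusable prefix KV
    prev: Optional[Block] = field(default=None, repr=False)
    next: Optional[Block] = field(default=None, repr=False)


class FreeBlockQueue:
    """Doubly linked list of free blocks with sentinel head/tail nodes."""

    def __init__(self, blocks: list[Block]) -> None:
        self.head, self.tail = Block(-1), Block(-1)
        self.head.next, self.tail.prev = self.tail, self.head
        self.num_free = 0
        for block in blocks:
            self.append(block)

    def append(self, block: Block) -> None:
        # Freed blocks join at the tail, so the head holds the longest-free ones.
        last = self.tail.prev
        last.next, block.prev = block, last
        block.next, self.tail.prev = self.tail, block
        self.num_free += 1

    def remove(self, block: Block) -> None:
        # O(1) unlink from anywhere, e.g. when a prefix-cache hit reclaims it.
        block.prev.next, block.next.prev = block.next, block.prev
        block.prev = block.next = None
        self.num_free -= 1

    def popleft(self) -> Block:
        # Ordinary allocation takes the block at the head.
        if self.num_free == 0:
            raise MemoryError("no free KV blocks")
        block = self.head.next
        self.remove(block)
        return block

    def indices(self) -> list[int]:
        out, node = [], self.head.next
        while node is not self.tail:
            out.append(node.idx)
            node = node.next
        return out


blocks = [Block(i) for i in range(4)]
queue = FreeBlockQueue(blocks)
blocks[2].block_hash = 0xABC   # pretend block 2 still caches a prompt prefix
queue.remove(blocks[2])        # prefix-cache hit: reuse it straight from the middle
first = queue.popleft()        # a fresh allocation takes the head instead
```

A Python deque could also pop from the head cheaply, but `deque.remove` is a linear scan. The explicit prev/next pointers are what make the middle unlink constant-time.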

Allocation: where the budget check actually lives

The scheduler never touches blocks directly. It asks the KVCacheManager for slots, and the answer to “can this request run this step?” is literally “did allocate_slots return a KVCacheBlocks object, or None?” The whole admission decision compresses to a free-block count:

# vllm/v1/core/kv_cache_manager.py  (allocate_slots)
num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate(...)
available_blocks = self.block_pool.get_num_free_blocks() - reserved_blocks
required_blocks = num_blocks_to_allocate + watermark_blocks
if required_blocks > available_blocks:
    # Cannot allocate new blocks
    return None

Source: vllm/v1/core/kv_cache_manager.py

Read that inequality slowly, because it is the entire admission decision in four lines. num_blocks_to_allocate is how many fresh physical blocks this step needs, computed by the coordinator from how many new tokens the request is about to produce and how many partially-filled blocks it already holds. get_num_free_blocks() is the live count of blocks currently on the free list. The two correction terms are headroom: watermark_blocks is a small reserve held back so that newly-admitted waiting requests cannot grab the last blocks and starve a running sequence into immediate preemption, and reserved_blocks holds capacity for in-flight sequences during asynchronous KV transfers (Chapters 16 and 17). If demand exceeds the headroom-adjusted supply, the function returns None and the request simply does not run this step. There is no error, no partial allocation, just a clean “not now.”
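Stripped of the coordinator, the check reduces to something like the sketch below. It assumes a plain list of free block ids and a single block size. blocks_needed and try_allocate are hypothetical helpers, not vLLM functions. The real num_blocks_to_allocate also accounts for prefix-cache hits and per-type layer groups.

```python
from typing import Optional


def blocks_needed(num_tokens_after_step: int, num_blocks_held: int, block_size: int) -> int:
    """Fresh blocks required so every token after this step has a slot."""
    total_blocks = -(-num_tokens_after_step // block_size)  # ceil division
    return max(0, total_blocks - num_blocks_held)


def try_allocate(free_blocks: list[int], num_needed: int,
                 watermark_blocks: int, reserved_blocks: int = 0) -> Optional[list[int]]:
    """All-or-nothing: return the new block ids, or None for "not now"."""
    available = len(free_blocks) - reserved_blocks
    if num_needed + watermark_blocks > available:
        return None  # nothing is popped, so a refusal has no side effects
    return [free_blocks.pop() for _ in range(num_needed)]


free = list(range(5))
# A request holding 2 full 16-token blocks crosses into token 33: it needs 1 block.
need = blocks_needed(num_tokens_after_step=33, num_blocks_held=2, block_size=16)
granted = try_allocate(free, need, watermark_blocks=1)
# With a single free block left, the watermark makes the same request wait.
refused = try_allocate([7], need, watermark_blocks=1)
```

The refusal path touches nothing, which is what lets the scheduler treat None as a clean "try again later."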

The flowchart below traces that decision and what each outcome triggers downstream.

flowchart TD
    A["Scheduler: can this request run this step?"] --> B["allocate_slots asks coordinator"]
    B --> C["required = blocks_needed + watermark"]
    C --> D["available = free_blocks - reserved"]
    D --> E{"required > available?"}
    E -->|"yes"| F["return None"]
    E -->|"no"| G["pop blocks from free list, return KVCacheBlocks"]
    F --> H["scheduler preempts or defers (Chapter 6)"]
    G --> I["request runs: write K/V into its slots this step"]

This is the function that connects every later chapter’s machinery back to physical memory. A None here is what triggers the preemption-as-load-shedding loop in Chapter 6. Both correction terms are admission-control knobs bolted onto the same simple inequality.

The docstring for allocate_slots carries an ASCII diagram of the block layout that repays close reading, because it is the conceptual map for prefix caching, chunked prefill, and disaggregation all at once:

# vllm/v1/core/kv_cache_manager.py  (allocate_slots docstring)
----------------------------------------------------------------------
| < comp > | < new_comp > | < ext_comp >  | < new >  | < lookahead > |
----------------------------------------------------------------------
                                          |   < to be computed >     |
----------------------------------------------------------------------
                          |            < to be allocated >           |
----------------------------------------------------------------------

Source: vllm/v1/core/kv_cache_manager.py

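comp is what the request has already computed. new_comp and ext_comp are tokens whose KV we get for free, from the local prefix cache and from an external store respectively. new is this step’s actual work, and lookahead reserves slots for speculative draft tokens (Chapter 13). Note where the two bottom rows begin. “To be computed” covers only new and lookahead, while “to be allocated” starts at ext_comp. Locally cached blocks already sit in the pool, but externally loaded KV still needs fresh blocks to land in. Every optimization later in the book is, in some sense, a different way of growing the free regions on the left so the “to be computed” region on the right shrinks. The allocator is the common substrate.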
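Read as arithmetic, the diagram says the following. This is a simplified sketch under stated assumptions: one block size, no sliding window, and blocks already held for exactly the comp and new_comp tokens. step_budget is a hypothetical helper, not the body of allocate_slots.

```python
def cdiv(a: int, b: int) -> int:
    return -(-a // b)


def step_budget(comp: int, new_comp: int, ext_comp: int, new: int,
                lookahead: int, block_size: int) -> tuple[int, int, int]:
    """Return (tokens to compute, tokens to allocate, fresh blocks needed)."""
    to_compute = new + lookahead                # only these run through the model
    to_allocate = ext_comp + new + lookahead    # these need slots to land in
    already_backed = comp + new_comp            # own blocks plus local prefix-cache hits
    total = already_backed + to_allocate
    fresh_blocks = cdiv(total, block_size) - cdiv(already_backed, block_size)
    return to_compute, to_allocate, fresh_blocks


# A new request with a 32-token local prefix hit, 16 tokens of externally loaded KV,
# 100 new prompt tokens, and 4 speculative slots.
print(step_budget(comp=0, new_comp=32, ext_comp=16, new=100, lookahead=4, block_size=16))
```

The prefix-cache hit shrinks both numbers. The external hit shrinks only the compute, because its KV still needs somewhere to live.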

Coordinating layers and types

One subtlety the contiguous-buffer mental model hides: “the KV cache” is not one cache. It is one per attention layer, and modern models mix attention types, with full-attention layers, sliding-window layers, and even Mamba state interleaved in the same network. vLLM handles this with a KVCacheCoordinator over per-type SingleTypeKVCacheManager instances, each owning the block bookkeeping for one flavor of layer:

# vllm/v1/core/single_type_kv_cache_manager.py
class SingleTypeKVCacheManager(ABC):
    """
    An abstract base class for a manager that handle the kv cache management
    logic of one specific type of attention layer.
    """

Source: vllm/v1/core/single_type_kv_cache_manager.py

Layers that need identically-shaped slots are grouped, and the scheduler allocates at the LCM of the groups’ block sizes so a single block-table index lands consistently across every layer in a group. The payoff of this indirection is that the scheduler keeps thinking in terms of one homogeneous block budget, while the messy reality of heterogeneous layers stays quarantined behind the coordinator. The KV blocks returned to the scheduler are deliberately wrapped to “hide KVCacheManager’s internal data structure from the Scheduler” (the KVCacheBlocks docstring in kv_cache_manager.py says exactly this). Clean seam, and it is what makes the hybrid-model and MLA cases in Chapter 9 tractable.
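The sketch below is a toy illustration of why aligning on the least common multiple helps. The group names and block sizes are hypothetical, and this is arithmetic only, not the coordinator's code. At any token count that is a multiple of the LCM, every group ends exactly on a whole-block boundary, so the groups agree on where complete blocks stop.

```python
from math import lcm


def blocks_per_group(num_tokens: int, block_sizes: dict[str, int]) -> dict[str, tuple[int, int]]:
    """For each group: (full blocks, tokens left over in a partial block)."""
    return {name: divmod(num_tokens, bs) for name, bs in block_sizes.items()}


sizes = {"full_attn": 16, "other_group": 64}   # hypothetical per-group block sizes
unit = lcm(*sizes.values())                     # 64

aligned = blocks_per_group(2 * unit, sizes)     # both groups end on a boundary
misaligned = blocks_per_group(80, sizes)        # 80 = 5 * 16, but not a multiple of 64
print(aligned, misaligned)
```

At 80 tokens the first group is block-aligned while the second holds a half-filled block. Choosing units of the LCM avoids that kind of disagreement.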

The paged write: where logical blocks meet the kernel

So far this is all CPU-side bookkeeping. The blocks are integers in a table. The actual K and V tensors have to land in the right physical block on the GPU each step, and that write is the bridge between the allocator and the attention kernel. In the FlashAttention backend it is a single scatter:

# vllm/v1/attention/backends/flash_attn.py
reshape_and_cache_flash(
    key,
    value,
    key_cache,
    value_cache,
    slot_mapping,
    self.kv_cache_dtype,
    layer._k_scale,
    layer._v_scale,
)

Source: vllm/v1/attention/backends/flash_attn.py

The load-bearing argument is slot_mapping: a per-token tensor of flat destination indices. The block table maps a logical block to a physical block. slot_mapping goes one level finer and maps each individual new token to the exact flat offset in the pool where its K and V must land this step. The worker derives it from the request’s block table each step: look up which physical block the next token belongs to, then add its offset inside that block. For a token in physical block $p$ at offset $o$, with block size $B$ (block_size), that is $\text{slot} = p \cdot B + o$. vllm/v1/worker/block_table.py computes precisely this as slot_ids = block_numbers * block_size + local_block_offsets. So when the kernel runs, it does not reason about blocks at all. It reads each new token’s K/V and writes it to the precomputed slot, anywhere in the pool, with no contiguity assumed. This is the moment paging stops being a metaphor and becomes a memory address. One detail from the source is worth noting. A comment just above this call explains that the op uses the shape of slot_mapping to find the real token count, so the caller does not even slice the padded K/V tensors first. That is a small concession to the CPU-overhead concerns from Chapter 3 that haunt this whole code path.
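Both halves of that write can be sketched in pure Python. The sketch assumes a 4-token block and a flat list standing in for the pool. Real caches are per-layer tensors with head and dtype dimensions, and compute_slot_mapping and scatter_kv are hypothetical names, not the vLLM kernels. Three new tokens straddle a block boundary, so their K/V land in two different physical blocks.

```python
BLOCK_SIZE = 4   # assumption: tiny blocks so the slots are easy to follow
NUM_BLOCKS = 6


def compute_slot_mapping(block_table: list[int], positions: list[int]) -> list[int]:
    """Flat pool index per new token: physical_block * BLOCK_SIZE + offset."""
    slots = []
    for pos in positions:
        physical_block = block_table[pos // BLOCK_SIZE]   # which block holds this token
        offset = pos % BLOCK_SIZE                        # where inside that block
        slots.append(physical_block * BLOCK_SIZE + offset)
    return slots


def scatter_kv(key_cache: list, value_cache: list, keys: list, values: list,
               slot_mapping: list[int]) -> None:
    """Write each new token's K and V into its precomputed slot."""
    # zip stops at the shortest input; with padded keys/values, slot_mapping sets the count.
    for k, v, slot in zip(keys, values, slot_mapping):
        key_cache[slot] = k
        value_cache[slot] = v


key_cache = [None] * (NUM_BLOCKS * BLOCK_SIZE)     # flat stand-ins for the pool
value_cache = [None] * (NUM_BLOCKS * BLOCK_SIZE)

block_table = [5, 1]        # logical block 0 -> physical 5, logical block 1 -> physical 1
positions = [3, 4, 5]       # three new tokens straddling the block boundary
slots = compute_slot_mapping(block_table, positions)
scatter_kv(key_cache, value_cache, ["k3", "k4", "k5", "pad"], ["v3", "v4", "v5", "pad"], slots)
```

The padded fourth entry is simply ignored. This loosely mirrors the real op taking the token count from slot_mapping rather than from the padded K/V tensors.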

That is the write side, run once per layer per step to deposit the step’s new keys and values. The corresponding read side is what “PagedAttention” actually names: when computing attention, the kernel must gather K and V for all previous tokens, which are scattered across this sequence’s blocks all over the pool. It walks the block table to find each block and reads from it, instead of streaming one contiguous buffer. The diagram below traces both halves of one decode step: the scatter-write that deposits the new token’s KV, then the gather-read that pulls the whole history back for the attention math.

flowchart TD
    subgraph Write["Write side (once per layer per step)"]
        W1["new token K/V on GPU"] --> W2["slot_mapping: token to flat pool offset"]
        W2 --> W3["reshape_and_cache_flash scatters into assigned slot"]
    end
    W3 --> POOL["KV block pool (scattered across GPU)"]
    subgraph Read["Read side (PagedAttention)"]
        R1["block table for this sequence"] --> R2["gather K/V from every past block"]
        R2 --> R3["attention over full history"]
    end
    POOL --> R2

In current vLLM that gather happens inside the FlashAttention varlen call, which takes the block table as an argument, rather than in a standalone kernel. This chapter sizes the cache, and Chapter 9 covers the kernel that reads it. The two are inseparable, though: paging the layout is useless without a kernel that can attend over the ragged result at the bandwidth roofline.
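The read side can be sketched the same way: walk the block table to gather the history in logical order, then run ordinary softmax attention over it. This is a toy that assumes single-head 2-dimensional vectors and a 2-token block. gather_history and attend are hypothetical helpers. A real kernel fuses the gather into the attention loop instead of materializing a copy.

```python
import math

BLOCK_SIZE = 2  # assumption: tiny blocks to keep the example readable


def gather_history(cache: list, block_table: list[int], num_tokens: int) -> list:
    """Collect a sequence's entries in logical order by walking its block table."""
    return [cache[block_table[pos // BLOCK_SIZE] * BLOCK_SIZE + pos % BLOCK_SIZE]
            for pos in range(num_tokens)]


def attend(query: list[float], keys: list[list[float]],
           values: list[list[float]]) -> list[float]:
    """Single-head scaled dot-product attention over the gathered history."""
    scale = 1.0 / math.sqrt(len(query))
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
    peak = max(scores)                                # subtract max for stability
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(len(values[0]))]


# Pool of 4 blocks x 2 slots; this sequence owns physical blocks 3 then 0.
key_cache = [None] * 8
value_cache = [None] * 8
key_cache[6], key_cache[7], key_cache[0] = [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]
value_cache[6], value_cache[7], value_cache[0] = [1.0, 0.0], [0.0, 1.0], [0.5, 0.5]

keys = gather_history(key_cache, block_table=[3, 0], num_tokens=3)
values = gather_history(value_cache, block_table=[3, 0], num_tokens=3)
out = attend([0.0, 0.0], keys, values)   # zero query: uniform weights, plain average
```

The attention math never sees physical block numbers. Only the gather does, which is why paging can stay invisible to the model code.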

What we bought, and what we did not

Paging effectively eliminates both fragmentation classes. Internal waste drops to under one block per sequence; external waste goes to zero by construction. In practice this is the difference between a handful of concurrent sequences and many tens of them on the same card, and because decode is memory-bound, more concurrent sequences is very nearly more throughput. It is the single highest-leverage memory decision in the engine.

It is not free. The block table is an indirection on the hottest path in the system, and the attention kernel now chases pointers instead of streaming a contiguous region, which costs some bandwidth efficiency and a great deal of kernel-engineering effort to claw back. There is per-step CPU work to build slot_mapping and the block tables, exactly the overhead the FlashAttention comment warns about. And paging does not shrink the cache, it only stops you from wasting it. The total KV footprint per token is unchanged; you have made the budget usable, not smaller. Shrinking the bytes per token is a separate lever, pulled by KV-cache quantization in Chapter 12.

There is also a quieter assumption buried in everything above: that the KV cache grows linearly with sequence length, $O(n)$ for a sequence of $n$ tokens, and must be retained in full. That is true for vanilla attention, and it is what makes memory the bottleneck. But it is a property of the architecture, not a law.

Forward pointer: not every model pays this tax

Two architectures break the linear-growth assumption, and both reappear later.

Multi-head Latent Attention compresses the per-token KV into a small shared latent vector rather than storing full K and V per head. The spec is still paged through the same machinery, but the page-size formula is entirely different, and the allocator even tracks a separate storage_block_size to account for the compression:

# vllm/v1/kv_cache_interface.py  (MLAAttentionSpec)
@property
def storage_block_size(self) -> int:
    return self.block_size // self.compress_ratio

Source: vllm/v1/kv_cache_interface.py

A sequence still grows its KV, but each token costs a fraction of what full attention charges. This is how the flagship 2026 open models keep long-context serving affordable, and it is the heart of Chapter 9.
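Here is a back-of-the-envelope comparison of how many numbers each layout stores per token per layer. The MLA dimensions (a 512-wide latent plus a 64-wide decoupled RoPE key) follow DeepSeek-V2's published configuration and are used here as assumptions. The GQA shape is the illustrative 8B-class one from earlier in the chapter, and the MHA shape (32 heads of width 128) is likewise assumed. So this compares layouts, not specific models. To get bytes per token, multiply by the layer count and the bytes per number.

```python
def per_layer_elems_full(num_kv_heads: int, head_size: int) -> int:
    # Full/GQA attention: separate K and V vectors for every KV head.
    return 2 * num_kv_heads * head_size


def per_layer_elems_mla(latent_dim: int, rope_dim: int) -> int:
    # MLA: one compressed latent (plus a small decoupled RoPE key) shared by all heads.
    return latent_dim + rope_dim


gqa = per_layer_elems_full(num_kv_heads=8, head_size=128)
mha = per_layer_elems_full(num_kv_heads=32, head_size=128)
mla = per_layer_elems_mla(latent_dim=512, rope_dim=64)
print(f"MLA/GQA = {mla / gqa:.2f}, MLA/MHA = {mla / mha:.3f}")
```

There is no factor of two in the MLA term, because K and V are both reconstructed from the same stored latent.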

Recurrent and state-space models, Mamba among them, are the more radical case: their state is constant size, $O(1)$ in the sequence length rather than $O(n)$. The spec advertises this directly. Where FullAttentionSpec.max_memory_usage_bytes scaled with max_model_len, the Mamba spec’s default mode returns essentially one page, regardless of how long the sequence runs:

# vllm/v1/kv_cache_interface.py  (MambaSpec.max_memory_usage_bytes)
else:
    return self.page_size_bytes * (1 + self.num_speculative_blocks)

Source: vllm/v1/kv_cache_interface.py

For such layers there is no growing cache to page, no per-token tax, and the bottleneck analysis of this chapter simply does not apply, which is exactly why hybrid models need the per-type coordinator we met above. The block-as-the-budget framing is the right default for the transformers that dominate production today, but it is a default, not a universal.
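The contrast between the two max_memory_usage_bytes shapes fits in a few lines. This sketch mirrors only the structure of the two excerpts above, and the functions are not vLLM's. The 64 KiB attention page reuses the assumed 8B-class shape from earlier, and the 2 MiB Mamba state page is a hypothetical size.

```python
def cdiv(a: int, b: int) -> int:
    return -(-a // b)


def full_attention_max_bytes(max_model_len: int, block_size: int, page_size_bytes: int) -> int:
    # Same shape as FullAttentionSpec: one page per block of the longest sequence.
    return cdiv(max_model_len, block_size) * page_size_bytes


def mamba_max_bytes(page_size_bytes: int, num_speculative_blocks: int = 0) -> int:
    # Same shape as MambaSpec's default branch: one state page, plus spec-decode pages.
    return page_size_bytes * (1 + num_speculative_blocks)


ATTN_PAGE = 64 * 1024      # 16 tokens x 8 KV heads x 128 x 2 (K, V) x 2 bytes (assumed shape)
MAMBA_PAGE = 2 * 1024**2   # hypothetical state size for one Mamba layer

for max_len in (8_192, 131_072):
    attn = full_attention_max_bytes(max_len, 16, ATTN_PAGE) / 2**20
    mamba = mamba_max_bytes(MAMBA_PAGE) / 2**20
    print(f"max_model_len={max_len:>7}: attention layer {attn:6.0f} MiB, "
          f"Mamba layer {mamba:.0f} MiB")
```

A 16x longer context makes the attention layer's worst case 16x larger, while the Mamba layer's stays put.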

With the cache sized and paged, the obvious next question is how the scheduler decides, every single step, which sequences get to grab those blocks. That is continuous batching, and it is where we turn next.

Further reading

  • Efficient Memory Management for Large Language Model Serving with PagedAttention — arXiv:2309.06180 — The originating paper; introduces paged KV blocks and a block table per sequence to kill internal and external fragmentation, and measures how badly contiguous allocation wasted memory. Read it for the fragmentation taxonomy and the OS-paging analogy that this whole chapter rests on.