Prefix and prompt caching: paying for shared context once

By now we have an engine that pages KV memory (Chapter 4), re-plans a token-budgeted batch every step (Chapter 5), and slices long prefills into chunks while shedding load through preemption-with-recompute (Chapter 6). Each of those mechanisms made the cost of a prefill more tolerable. None of them questioned whether the prefill needed to happen at all.

It usually doesn’t. Look at the actual prompts hitting a production endpoint and you find the same bytes over and over: a 2,000-token system prompt prepended to every request, a fixed block of few-shot examples shared across a batch job, a chat history that grows by one turn but repeats the previous twelve. The compute-bound prefill that Chapter 3 told us is the expensive phase is, in real traffic, mostly recomputation of context the engine already saw seconds ago.

The insight that turns this waste into the single biggest TTFT lever you have is almost embarrassingly simple. The KV cache for a span of tokens is a pure function of those tokens (and a few side conditions we will get to). To see why, recall from Chapter 4 what a KV block actually contains: for each token in the block, the key and value vectors that token produces at every attention layer. Those vectors depend only on the token sequence up to and including that position, the model weights, and the position itself. Run the same prefix through the same model twice and you get the same numbers, down to the bit. So if two requests share a prefix, the KV blocks for that prefix are bit-identical. There is no per-request state hiding in them.

And the engine already has somewhere to keep those blocks: the idle KV pool. This is the crucial observation. When a request finishes, Chapter 4’s allocator returns its blocks to a free pool to be handed out later, but it does not zero them. A block that has been freed but not yet overwritten still physically holds valid KV state from the request that last used it. Normally that stale data is garbage we are about to clobber. But if we can find that block again by its contents before something else overwrites it, we can hand it straight to a new request and skip recomputing its prefill entirely. The free pool stops being just a free pool and becomes a content-addressed cache that happens to share storage with live allocations. The same physical block is, at any instant, either live KV for some running request, reusable cache for a future one, or scratch space about to be recycled, and the rest of this chapter is about managing those three roles on a single pool of memory.

That dual identity is the heart of the design, so it is worth seeing whole before we descend into the parts. The diagram below traces the two paths a freshly arrived request can take.

flowchart TD
    A["new request arrives"] --> B["hash its tokens block-by-block"]
    B --> C{"prefix blocks<br/>already cached?"}
    C -->|"hit"| D["reuse cached KV blocks<br/>skip their prefill"]
    C -->|"miss"| E["prefill: compute KV from scratch"]
    D --> F["compute only the uncached suffix"]
    E --> G["store new blocks in cache"]
    F --> G
    G --> H["decode normally"]
    H --> I["request finishes"]
    I --> J["blocks freed but kept as cache<br/>(not zeroed)"]
    J -.->|"later request<br/>with same prefix"| C

Naming a block by its contents

To look a block up by content, you need a name derived from content: a short, fixed-size fingerprint that two requests will compute identically if and only if their tokens match. That fingerprint is a hash. But a naive hash of one block’s 16 tokens is not enough, and seeing exactly why not is the key to this whole section.

Imagine you only hashed each block’s own tokens. Then the block holding tokens 16 through 31 would get the same name no matter what came before it in positions 0 through 15. But its KV vectors are not the same regardless of what came before, because attention at those positions looked back at the earlier tokens. Reusing that block when the earlier context differed would hand a request KV computed for someone else’s prefix. The name has to encode the entire history leading up to the block, not just the block.
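The failure is easy to reproduce. The sketch below is not vLLM code: `naive_block_name` is a hypothetical helper that names a block by its own tokens only, and SHA-256 over a `repr` stands in for whatever hash function the engine actually uses.

```python
import hashlib

BLOCK_SIZE = 16  # tokens per block, matching the chapter's running example


def naive_block_name(block_tokens):
    """Hypothetical naive scheme: name a block by its own tokens only."""
    return hashlib.sha256(repr(tuple(block_tokens)).encode()).hexdigest()


# Two prompts with different first blocks but an identical second block.
prompt_a = list(range(0, 16)) + list(range(100, 116))
prompt_b = list(range(50, 66)) + list(range(100, 116))

second_a = naive_block_name(prompt_a[BLOCK_SIZE:2 * BLOCK_SIZE])
second_b = naive_block_name(prompt_b[BLOCK_SIZE:2 * BLOCK_SIZE])

# Same name, even though the KV at these positions attended to different
# earlier tokens and therefore differs between the two prompts.
print(second_a == second_b)  # True: a false cache hit
```

Under this scheme the cache would serve prompt B the second block computed for prompt A, which is exactly the corruption the chained name rules out.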

vLLM solves this by chaining. It computes a name when a block fills, in get_request_block_hasher, walking the request’s tokens block-by-block and folding each block’s name into the next:

# vllm/v1/core/kv_cache_utils.py
block_hash = hash_block_tokens(
    caching_hash_fn, prev_block_hash_value, block_tokens, extra_keys
)
new_block_hashes.append(block_hash)
start_token_idx += block_size
prev_block_hash_value = block_hash

Source: vllm/v1/core/kv_cache_utils.py

The chaining is the whole trick. A block’s hash folds in its parent’s hash (the name of the block immediately before it), so the name of block N encodes the entire token sequence from position 0 through the end of block N, not just block N’s own 16 tokens. Each name is built from the name before it, like a tiny blockchain over the prompt. The hash itself:

# vllm/v1/core/kv_cache_utils.py
def hash_block_tokens(
    hash_function, parent_block_hash, curr_block_token_ids, extra_keys=None,
) -> BlockHash:
    if not parent_block_hash:
        parent_block_hash = NONE_HASH
    curr_block_token_ids_tuple = tuple(curr_block_token_ids)
    return BlockHash(
        hash_function((parent_block_hash, curr_block_token_ids_tuple, extra_keys))
    )

Source: vllm/v1/core/kv_cache_utils.py

Two consequences fall out of the parent chaining, and both are visible in the diagram below. It shows two requests that share a system prompt (blocks 0 and 1) and then diverge: request A continues with its own question, request B with a different one. Because each block’s name is computed from its parent’s name plus its own tokens, the two requests produce identical names for blocks 0 and 1, then different names from block 2 onward where their tokens differ.

flowchart LR
    seed["NONE_HASH<br/>(seed)"] --> h0["hash 0<br/>tokens 0..15<br/>(shared system prompt)"]
    h0 --> h1["hash 1<br/>tokens 16..31<br/>(shared system prompt)"]
    h1 --> h2a["hash 2a<br/>request A question"]
    h1 --> h2b["hash 2b<br/>request B question"]
    h2a --> h3a["hash 3a<br/>request A cont."]
    h2b --> h3b["hash 3b<br/>request B cont."]

First, a cache hit on block N guarantees hits on all of blocks 0..N-1, because their contents are baked into N’s name. If request B’s hash 2b ever shows up in the cache, the only way it could have been computed is from a request that also produced hash 1 and hash 0, so those must be cached too. That is what makes prefix lookup a single forward walk rather than a search: matching the last block of a prefix proves the entire prefix matched. Second, prefixes must be identical from the very first token to match. The seed NONE_HASH feeds the very first block, so a difference anywhere in block 0 changes hash 0, which changes every name downstream. Two requests that differ only in their first token share nothing, which is exactly why you want shared context to be a literal prefix (system prompt first, then the variable part) and not, say, interleaved.
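Both consequences can be checked with a minimal sketch of parent chaining. It assumes SHA-256 over a `repr` of the tokens and a fixed all-zero seed in place of NONE_HASH; `chained_block_hashes` is an illustrative name, not a vLLM function.

```python
import hashlib

BLOCK_SIZE = 16
SEED = b"\x00" * 32  # stand-in for NONE_HASH, fixed so runs are repeatable


def chained_block_hashes(token_ids, seed=SEED):
    """Return one digest per full block; each folds in its parent's digest."""
    hashes = []
    parent = seed
    for start in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE):
        block = tuple(token_ids[start:start + BLOCK_SIZE])
        parent = hashlib.sha256(parent + repr(block).encode()).digest()
        hashes.append(parent)
    return hashes


system_prompt = list(range(32))                      # two shared blocks
request_a = system_prompt + list(range(1000, 1016))  # A's question
request_b = system_prompt + list(range(2000, 2016))  # B's question
request_c = [999] + request_a[1:]                    # differs only in token 0

hashes_a = chained_block_hashes(request_a)
hashes_b = chained_block_hashes(request_b)
hashes_c = chained_block_hashes(request_c)

print(hashes_a[:2] == hashes_b[:2])  # True: shared prefix, shared names
print(hashes_a[2] == hashes_b[2])    # False: names diverge with the tokens
print(any(x == y for x, y in zip(hashes_a, hashes_c)))  # False: nothing shared
```

Request C changes a single token in block 0 and loses every block, including the 31 system-prompt tokens it still has in common with A.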

The first block has no parent, so the chain needs a seed. That is NONE_HASH, and how it’s chosen matters more than it looks:

# vllm/v1/core/kv_cache_utils.py
if hash_seed is None:
    NONE_HASH = BlockHash(os.urandom(32))
else:
    NONE_HASH = BlockHash(hash_fn(hash_seed))

Source: vllm/v1/core/kv_cache_utils.py

Within one process the seed is constant, so hashes are stable. Across processes, an unset PYTHONHASHSEED leaves hash_seed as None, so every replica draws its own random 32-byte seed and block names are deliberately not portable between machines. That is fine for single-replica caching and becomes a real design question the moment you want a router to reason about which replica holds which prefix, which is Chapter 18’s problem.
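A small sketch of the portability point, assuming SHA-256 as the hash and modelling each replica as one call with its own seed. `first_block_name` and the pinned seed string are illustrative, not vLLM names.

```python
import hashlib
import os


def first_block_name(seed, block_tokens):
    """Name of block 0: the seed plays the role of the missing parent."""
    return hashlib.sha256(seed + repr(tuple(block_tokens)).encode()).hexdigest()


tokens = list(range(16))

# Unset seed: each process draws its own random bytes.
replica_1 = first_block_name(os.urandom(32), tokens)
replica_2 = first_block_name(os.urandom(32), tokens)

# Pinned seed: every process derives the same bytes from the same string.
pinned_seed = hashlib.sha256(b"0").digest()
replica_3 = first_block_name(pinned_seed, tokens)
replica_4 = first_block_name(pinned_seed, tokens)

print(replica_1 == replica_2)  # False, except with negligible probability
print(replica_3 == replica_4)  # True
```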

Then there is extra_keys. A raw content hash would happily let one request reuse another’s KV blocks even when the KV is not actually interchangeable. generate_block_hash_extra_keys is the guard:

# vllm/v1/core/kv_cache_utils.py
extra_keys: list[Any] = (
    lora_extra_keys + mm_extra_keys + cache_salt_keys + prompt_embeds_keys
)

Source: vllm/v1/core/kv_cache_utils.py

Three of these four are worth holding onto. The LoRA name forks the cache per adapter, because the same tokens produce different KV under different fine-tuned weights (Chapter 19 leans on exactly this). The multimodal hash keys a placeholder block to the actual image or audio that filled it, so two requests with the same <image> token but different pixels don’t collide (Chapter 8 builds its encoder cache on this same hash). And cache_salt is a deliberate escape hatch: a caller who does not want their prefix shared with other tenants passes a salt that affects only the first block’s keys, which poisons the whole chain from the root and gives a clean isolation boundary without a separate cache. The honest tradeoff is that prefix caching is, by default, cross-request and cross-user. In a multi-tenant deployment the salt is how you close the timing side channel, and because it is opt-in, the secure-by-default story is on you.
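To see how a salt on the first block isolates a whole chain, here is a toy hasher with two of the extra keys folded in. It is a sketch under stated assumptions: SHA-256 as the hash, a fixed seed, the adapter name attached to every block, and the salt attached to block 0 only. The function name and its parameters are illustrative.

```python
import hashlib

BLOCK_SIZE = 16
SEED = b"\x00" * 32  # stand-in for NONE_HASH


def block_hashes(token_ids, cache_salt=None, lora_name=None):
    """Chained block hashes with optional extra keys mixed into each block."""
    hashes, parent = [], SEED
    starts = range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE)
    for index, start in enumerate(starts):
        block = tuple(token_ids[start:start + BLOCK_SIZE])
        extra = []
        if lora_name is not None:
            extra.append(lora_name)   # assumed: adapter name on every block
        if cache_salt is not None and index == 0:
            extra.append(cache_salt)  # salt touches the first block only
        payload = repr((block, tuple(extra))).encode()
        parent = hashlib.sha256(parent + payload).digest()
        hashes.append(parent)
    return hashes


prompt = list(range(48))  # three full blocks
shared = block_hashes(prompt)
tenant_1 = block_hashes(prompt, cache_salt="tenant-1")
tenant_1_again = block_hashes(prompt, cache_salt="tenant-1")
tenant_2 = block_hashes(prompt, cache_salt="tenant-2")
adapter = block_hashes(prompt, lora_name="adapter-x")

print(tenant_1 == tenant_1_again)                      # True: same salt shares
print(any(a == b for a, b in zip(shared, tenant_1)))   # False: fully isolated
print(any(a == b for a, b in zip(tenant_1, tenant_2))) # False
print(any(a == b for a, b in zip(shared, adapter)))    # False
```

Although the salt enters only at block 0, blocks 1 and 2 inherit the difference through the parent hash, so no block of a salted chain can match an unsalted one.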

The longest-prefix walk

With names in hand, a new request’s cache lookup is get_computed_blocks:

# vllm/v1/core/kv_cache_manager.py
# NOTE: When all tokens hit the cache, we must recompute the last token
# to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1.
max_cache_hit_length = request.num_tokens - 1
computed_blocks, num_new_computed_tokens = (
    self.coordinator.find_longest_cache_hit(
        request.block_hashes, max_cache_hit_length
    )
)

Source: vllm/v1/core/kv_cache_manager.py

That - 1 is small and easy to miss, but it encodes something fundamental from Chapter 3: prefill computes KV state but produces no sampled token until you run a forward pass that emits logits (the model’s score over the vocabulary for the next position). Logits come out of the forward pass on the last token of the prompt. If a request were a 100% cache hit, every token’s KV would already be sitting in cache and the engine would skip the forward pass entirely, leaving it with no logits for the next position and nothing to sample. The first decode step would have nothing to start from. So the engine caps the hit one token short of the prompt and always recomputes at least the final position, guaranteeing one real forward pass that yields the logits decode needs. Because hits are block-aligned, in practice this re-runs the whole last block, a deliberate inefficiency the code flags as future cleanup.
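The arithmetic is worth doing once by hand. The sketch below assumes the hit length is rounded down to whole 16-token blocks; `reusable_tokens` is an illustrative helper, not part of vLLM.

```python
BLOCK_SIZE = 16


def reusable_tokens(prompt_len, cached_full_blocks):
    """Tokens served from cache once the hit is capped at prompt_len - 1."""
    max_cache_hit_length = prompt_len - 1
    max_blocks = max_cache_hit_length // BLOCK_SIZE  # hits are whole blocks
    return min(cached_full_blocks, max_blocks) * BLOCK_SIZE


# 64-token prompt, all four blocks cached: the cap drops the last block.
print(reusable_tokens(64, 4))  # 48, so 16 tokens are recomputed
# 70-token prompt, four blocks cached: the cap costs nothing extra.
print(reusable_tokens(70, 4))  # 64, so only the 6-token tail is computed
```

The penalty only bites when the prompt is fully cached and ends exactly on a block boundary. Any prompt with an uncached tail already has a forward pass to run.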

The walk itself, for full attention, is the loop the parent-chaining was designed for:

# vllm/v1/core/single_type_kv_cache_manager.py
for block_hash in itertools.islice(block_hashes, max_num_blocks):
    # block_hashes is a chain of block hashes. If a block hash is not
    # in the cached_block_hash_to_id, the following block hashes are
    # not computed yet for sure.
    if cached_block := block_pool.get_cached_block(
        block_hash, kv_cache_group_ids
    ):
        for computed, cached in zip(computed_blocks, cached_block):
            computed.append(cached)
    else:
        break

Source: vllm/v1/core/single_type_kv_cache_manager.py

Walk the chain, stop at the first miss. Concretely: take the request’s list of block hashes (computed exactly as in the previous section), look up the first one in cached_block_hash_to_id, the dictionary mapping a block name to the physical block currently holding that KV. If it is present, that block is a hit; record it and move to the next name. The instant a name is absent, stop. The comment states the invariant the whole design rests on, and it is the first consequence from the diagram above turned into a loop guard: a miss at block K means K+1 onward cannot be cached either (their names were derived from K’s, so if K was never stored, nothing after it could have been), so there is no point looking. The cost is linear in the number of hit blocks, one dictionary lookup each, no tree traversal and no scanning past the first miss.

The tokens those hit blocks cover are returned as num_new_computed_tokens, and here is where the cache rejoins the rest of the engine. The scheduler simply advances the request’s num_computed_tokens past them, the same counter Chapter 5 uses to track how far a request has been prefilled. To the scheduler, a request with a cache hit is indistinguishable from one whose first few prefill chunks already ran in earlier steps: it is just a request that arrives partway done. Continuous batching from Chapter 5 needs no special case at all, which is why the cache could be bolted onto the existing scheduler without disturbing it.
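Stripped of KV cache groups and the block pool, the walk reduces to a few lines. This is a simplified sketch with a plain dict standing in for cached_block_hash_to_id and short strings standing in for block hashes.

```python
BLOCK_SIZE = 16


def find_longest_cache_hit(block_hashes, cached_block_hash_to_id, max_num_blocks):
    """Walk the hash chain and stop at the first miss."""
    hit_block_ids = []
    for block_hash in block_hashes[:max_num_blocks]:
        block_id = cached_block_hash_to_id.get(block_hash)
        if block_id is None:
            break  # later names derive from this one, so stop looking
        hit_block_ids.append(block_id)
    return hit_block_ids


cache = {"h0": 7, "h1": 3, "h3": 9}    # block name -> physical block id
request = ["h0", "h1", "h2", "h3"]     # chain for a four-block prompt

hits = find_longest_cache_hit(request, cache, max_num_blocks=4)
num_new_computed_tokens = len(hits) * BLOCK_SIZE

print(hits)                     # [7, 3]
print(num_new_computed_tokens)  # 32
```

The entry for "h3" is never consulted: the loop exits at "h2", and the scheduler starts this request with 32 tokens already counted as computed.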

Keeping a freed block alive long enough to be found

The subtle part is not the lookup. It is that the blocks being looked up are simultaneously eviction candidates. This is the three-roles tension from the opening section made concrete. A cached block whose owning request has finished sits in the free queue with ref_cnt == 0 (its reference count, the number of live requests using it, has dropped to zero), available to be handed out for new allocations, and still listed in the hash map as a valid cache hit. It is in two data structures at once, playing two roles, and the two roles have to be reconciled the instant either one fires.

The state machine below is the spine of this section. A block moves between three states: live (held by one or more running requests, ref_cnt > 0), cached-and-free (ref_cnt == 0, sitting in the free queue but still findable by name), and the brief act of being recycled into a fresh live block. Two events drive the interesting transitions out of cached-and-free, and the rest of this section is just those two events: a cache hit (touch) pulls the block back to live, while an allocation under memory pressure (get_new_blocks) evicts it.

stateDiagram-v2
    [*] --> Live: allocate for a request
    Live --> Live: more requests share it, ref_cnt++
    Live --> CachedFree: request finishes, ref_cnt to 0, block kept in hash map
    CachedFree --> Live: touch, cache hit rescues it, remove from free queue
    CachedFree --> Recycled: get_new_blocks pops it for a new allocation
    Recycled --> Live: evict old hash, reset block, ref_cnt++
    Recycled --> [*]

Start with the rescue. When a hit lands on a block that is sitting in the free queue, the engine has to yank it out of that queue before it gets handed to someone else. That is touch:

# vllm/v1/core/block_pool.py
def touch(self, blocks: Sequence[KVCacheBlock]) -> None:
    for block in blocks:
        # ref_cnt=0 means this block is in the free list (i.e. eviction
        # candidate), so remove it.
        if block.ref_cnt == 0 and not block.is_null:
            self.free_block_queue.remove(block)
        block.ref_cnt += 1

Source: vllm/v1/core/block_pool.py

A hit rescues the block from the free queue and bumps its refcount, pulling it back from the brink of eviction. The FreeKVCacheBlockQueue is a hand-rolled doubly linked list precisely so this mid-queue removal is $O(1)$ (the docstring is candid that it exists in place of Python’s C-implemented deque only because it needs interior deletion). This is a “lazy LRU”: blocks are never proactively evicted on a timer or under a budget. They linger in the free list, fully reusable as cache, until the moment their physical slot is actually needed for a fresh allocation.
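The reason a linked list wins here is that each block carries its own prev and next pointers, so unlinking needs no search. The toy below is a sketch of that idea with sentinel nodes. The class and method names are illustrative and the real queue has more to it.

```python
class Block:
    def __init__(self, block_id):
        self.block_id = block_id
        self.ref_cnt = 0
        self.prev = None
        self.next = None


class FreeQueue:
    """Toy doubly linked free list with sentinel head and tail nodes."""

    def __init__(self):
        self.head = Block(-1)
        self.tail = Block(-1)
        self.head.next = self.tail
        self.tail.prev = self.head

    def append(self, block):
        last = self.tail.prev
        last.next, block.prev = block, last
        block.next, self.tail.prev = self.tail, block

    def remove(self, block):
        # O(1): the block knows its neighbours, so no scan is needed.
        block.prev.next = block.next
        block.next.prev = block.prev
        block.prev = block.next = None

    def popleft(self):
        first = self.head.next
        if first is self.tail:
            raise IndexError("no free blocks")
        self.remove(first)
        return first

    def ids(self):
        out, node = [], self.head.next
        while node is not self.tail:
            out.append(node.block_id)
            node = node.next
        return out


def touch(queue, blocks):
    """A cache hit: pull free blocks out of the queue, then take a reference."""
    for block in blocks:
        if block.ref_cnt == 0:
            queue.remove(block)
        block.ref_cnt += 1


queue = FreeQueue()
blocks = [Block(i) for i in range(4)]
for b in blocks:
    queue.append(b)

touch(queue, [blocks[2]])     # hit on a block in the middle of the queue
print(queue.ids())            # [0, 1, 3]
recycled = queue.popleft()    # the next allocation takes the head
print(recycled.block_id)      # 0
```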

That moment is get_new_blocks, and the eviction is incidental:

# vllm/v1/core/block_pool.py
ret: list[KVCacheBlock] = self.free_block_queue.popleft_n(num_blocks)
if self.enable_caching:
    for block in ret:
        self._maybe_evict_cached_block(block)
        assert block.ref_cnt == 0
        block.ref_cnt += 1

Source: vllm/v1/core/block_pool.py

You allocate by popping from the head of the free queue, and only then, lazily, do you tear down whatever cache entry the recycled block was still serving. _maybe_evict_cached_block resets the block’s hash and removes it from the lookup map, which is the Recycled --> Live edge in the state machine above: the moment of eviction and re-allocation are the same moment. The queue’s ordering is what makes this a least-recently-used (LRU) policy: freed blocks go to the tail, allocation takes from the head, so blocks that have sat unused the longest (the coldest) are reused first and recently-touched hot prefixes, near the tail, survive.

There is a deliberate refinement in which order a single finished request’s blocks rejoin the queue. They are freed “in reverse order so that the tail blocks are evicted first.” Think about which block of a shared system prompt is most valuable to keep: block 0, the root, because every future request that shares any of this prefix must hit block 0 first. The last block of the request, by contrast, holds that request’s unique tail and will likely never be hit again. Freeing in reverse pushes the unique tail blocks toward the head of the queue (evicted soonest) and leaves the shared root blocks nearer the tail (evicted last), so the part of the prefix most likely to be reused is the part the LRU protects hardest.
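The interaction of lazy eviction and reverse-order freeing can be traced in a toy pool. This sketch swaps the linked list for a `collections.OrderedDict` and registers a block's hash at allocation time rather than when the block fills. Both are simplifications, and every name in it is illustrative.

```python
from collections import OrderedDict


class ToyBlockPool:
    def __init__(self, num_blocks):
        # Free queue: first key is the head (recycled next), last is the tail.
        self.free = OrderedDict((i, None) for i in range(num_blocks))
        self.hash_of = {}   # block id -> content hash it currently caches
        self.block_of = {}  # content hash -> block id (the lookup map)
        self.ref_cnt = {i: 0 for i in range(num_blocks)}

    def get_new_block(self, content_hash):
        block_id, _ = self.free.popitem(last=False)  # pop the head
        old_hash = self.hash_of.pop(block_id, None)
        if old_hash is not None:
            del self.block_of[old_hash]  # eviction happens only here
        self.hash_of[block_id] = content_hash
        self.block_of[content_hash] = block_id
        self.ref_cnt[block_id] = 1
        return block_id

    def touch(self, content_hash):
        block_id = self.block_of.get(content_hash)
        if block_id is None:
            return None
        if self.ref_cnt[block_id] == 0:
            del self.free[block_id]  # rescue from the free queue
        self.ref_cnt[block_id] += 1
        return block_id

    def free_request(self, block_ids):
        # Reverse order: the request's tail block lands nearest the head.
        for block_id in reversed(block_ids):
            self.ref_cnt[block_id] -= 1
            if self.ref_cnt[block_id] == 0:
                self.free[block_id] = None  # append at the tail


pool = ToyBlockPool(3)
request_blocks = [pool.get_new_block(h) for h in ("sys0", "sys1", "tailA")]
pool.free_request(request_blocks)
order_after_free = list(pool.free)
print(order_after_free)              # [2, 1, 0]: unique tail is evicted first

rescued = pool.touch("sys0")         # a new request hits the root block
new_block = pool.get_new_block("other")
print(rescued, new_block)            # 0 2
print("tailA" in pool.block_of)      # False: evicted lazily by the allocation
print("sys1" in pool.block_of)       # True: still findable while free
```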

The payoff is measured by two counters wired straight into the scheduler’s stats:

# vllm/v1/core/kv_cache_manager.py
self.prefix_cache_stats.record(
    num_tokens=request.num_tokens,
    num_hits=num_new_computed_tokens,
    preempted=request.num_preemptions > 0,
)

Source: vllm/v1/core/kv_cache_manager.py

These surface as vllm:prefix_cache_queries and vllm:prefix_cache_hits, both counted in tokens, not requests. That denominator is the right one: it tells you what fraction of prefill tokens you avoided computing, which maps almost directly to TTFT saved. Because prefill cost is roughly linear in the number of tokens actually computed, the uncached prefill left on the critical path is just $(1 - \text{hit rate}) \times \text{prompt length}$, and the curve below shows that line falling for a 2,000-token and a 4,000-token prompt. At an 80% hit rate, the 2,000-token prompt leaves only 400 tokens of prefill on the critical path. A hit rate of 80% on a workload with long shared system prompts is not unusual.

Figure (illustrative): uncached prefill tokens against hit rate, for a 2,000-token and a 4,000-token prompt. Assumes prefill cost is linear in tokens computed, so uncached prefill = $(1-\text{hit rate}) \times \text{prompt length}$; absolute TTFT depends on model and hardware.
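The same numbers fall out of two one-line helpers. These are illustrative functions under the linear-cost assumption above, fed with made-up counter values rather than real measurements.

```python
def token_hit_rate(queries, hits):
    """Fraction of queried prompt tokens that were served from cache."""
    return hits / queries if queries else 0.0


def uncached_prefill_tokens(prompt_len, hit_rate):
    """Tokens still on the critical path, assuming linear prefill cost."""
    return round((1.0 - hit_rate) * prompt_len)


# Hypothetical counter readings, both in tokens.
rate = token_hit_rate(queries=10_000, hits=8_000)

print(rate)                                 # 0.8
print(uncached_prefill_tokens(2000, rate))  # 400
print(uncached_prefill_tokens(4000, rate))  # 800
```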

Watch these two series together with the preemption counter from Chapter 6: a cache that thrashes (high queries, collapsing hits, rising preemptions) means your working set of prefixes no longer fits, and no scheduler tuning will fix a capacity problem.

Emitting the cache as a stream of events

So far everything is local to one replica. But the block lifecycle is also observable, and that turns out to matter for the fleet. When a block is cached, cache_full_blocks enqueues a BlockStored; when one is evicted, _maybe_evict_cached_block enqueues a BlockRemoved:

# vllm/distributed/kv_events.py
class BlockStored(KVCacheEvent):
    block_hashes: list[ExternalBlockHash]
    parent_block_hash: ExternalBlockHash | None
    token_ids: list[int]
    block_size: int

Source: vllm/distributed/kv_events.py

Each event carries the block’s hash and its parent’s hash. That parent pointer is not decoration: it is the same parent link the in-process hash chain used, now shipped over the wire. A stream of (hash, parent_hash) pairs is the prefix tree, serialized. Replay the stream and you rebuild the tree edge by edge, because each BlockStored tells you “this block’s name, and the name of the block it hangs off of,” which is exactly an edge. An external consumer subscribing to these events can therefore reconstruct an approximate model of which prefixes live on which replica without ever touching that replica’s memory.
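Rebuilding the tree from that stream takes very little code. The sketch below simplifies each event to a single (hash, parent_hash) pair with short strings for hashes, whereas the real BlockStored carries a list of block hashes. The function names are illustrative.

```python
def rebuild_prefix_tree(events):
    """Map each parent hash to the set of child hashes stored under it."""
    children = {}
    for block_hash, parent_hash in events:
        children.setdefault(parent_hash, set()).add(block_hash)
    return children


def cached_prefix_depth(children, request_hashes):
    """Number of leading blocks of a request that the tree already contains."""
    parent, depth = None, 0
    for block_hash in request_hashes:
        if block_hash not in children.get(parent, set()):
            break
        parent = block_hash
        depth += 1
    return depth


# One stored event per block; None marks a block with no parent (the root).
events = [("H0", None), ("H1", "H0"), ("H2a", "H1"), ("H2b", "H1")]
tree = rebuild_prefix_tree(events)

print(sorted(tree["H1"]))                             # ['H2a', 'H2b']
print(cached_prefix_depth(tree, ["H0", "H1", "H2c"])) # 2
```

A consumer that also applies BlockRemoved events would delete edges as blocks are evicted, which is why the reconstructed model is only ever approximate: it lags the replica by however long the events take to arrive.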

The sequence below traces the payoff: a router watches the event streams from two replicas, learns that replica A cached a particular system prompt, and steers the next request carrying that prompt to A so its prefill turns into a cache hit.

sequenceDiagram
    participant RA as Replica A
    participant Router as Cache-aware router
    participant Client
    RA->>Router: BlockStored hash=H0 parent=none
    RA->>Router: BlockStored hash=H1 parent=H0
    Note over Router: rebuild prefix tree, A holds H0 to H1
    Client->>Router: request with same system prompt
    Note over Router: its prefix hashes to H0 and H1, A has them
    Router->>RA: route here for a cache hit
    RA-->>Client: fast TTFT, prefill skipped

This is the signal a cache-aware router consumes to send your request to the machine that already holds your system prompt, and we have deliberately left the wire format and the routing policy as a thread to pull in Chapter 16 (offloading evicted blocks to a larger store) and Chapter 18 (routing on this signal). For now the point is only that the engine was built to emit its cache state, not just hold it.
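As a preview of how such a router might use the stream, here is a deliberately naive policy: track the set of hashes each replica has announced and send a request to the replica matching the longest prefix. It assumes every replica was started with the same hash seed so that names are comparable across machines. `ToyRouter` and its methods are hypothetical and are not the design later chapters describe.

```python
class ToyRouter:
    """Tracks which block hashes each replica has announced as cached."""

    def __init__(self, replicas):
        self.cached = {name: set() for name in replicas}

    def on_stored(self, replica, block_hash):
        self.cached[replica].add(block_hash)

    def on_removed(self, replica, block_hash):
        self.cached[replica].discard(block_hash)

    def matched_blocks(self, replica, request_hashes):
        count = 0
        for block_hash in request_hashes:
            if block_hash not in self.cached[replica]:
                break  # prefixes must match from the first block
            count += 1
        return count

    def pick(self, request_hashes):
        # Ties go to the first replica in insertion order.
        return max(
            self.cached,
            key=lambda name: self.matched_blocks(name, request_hashes),
        )


router = ToyRouter(["A", "B"])
router.on_stored("A", "H0")
router.on_stored("A", "H1")

request = ["H0", "H1", "H2"]
first_choice = router.pick(request)
print(first_choice)   # A

# Replica A evicts the prefix, and replica B later caches its first block.
router.on_removed("A", "H1")
router.on_removed("A", "H0")
router.on_stored("B", "H0")
second_choice = router.pick(request)
print(second_choice)  # B
```

A real policy would also weigh load, since always chasing the warmest cache concentrates traffic on one replica.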

A different shape for the same idea: RadixAttention

vLLM’s structure is a flat hash map from block name to block, with the prefix tree existing only implicitly in the parent-chained hashes. SGLang took the other branch: make the tree explicit. Its RadixAttention keeps cached KV in a radix tree (a compressed trie) keyed on the token sequence, where shared prefixes are literally shared tree edges and the LRU eviction policy operates on tree leaves.

SGLang: Efficient Execution of Structured LM Programs (arXiv:2312.07104) introduces RadixAttention, treating the KV cache as a radix tree so prefix sharing and eviction are tree operations rather than hash-map operations.

The contrast is instructive precisely because both systems solve the same problem. vLLM’s flat map gives $O(1)$ per-block lookup and a dead-simple data structure, at the cost of fixed 16-token block granularity (the DEFAULT_BLOCK_SIZE in vllm/config/cache.py): a shared prefix that ends mid-block can’t be partially reused. A radix tree can split at arbitrary token boundaries and makes the “longest common prefix across many cached sequences” query native, which is appealing when prefixes branch heavily, as they do in structured agent programs where one prompt forks into many continuations. The cost is a more intricate structure and tree maintenance on every insert and evict. Neither is strictly better: vLLM’s bet is that block granularity is cheap enough and the flat map’s simplicity buys robustness, while SGLang’s bet is that explicit tree structure pays for itself when sharing is the common case. The paper is worth reading for the tree-eviction details, which are more subtle than vLLM’s lazy LRU and have to keep interior nodes alive while shedding leaves.
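The granularity gap is simple to quantify. The sketch below compares token-level reuse, which is what a structure that splits at arbitrary token boundaries could offer, with block-aligned reuse at 16 tokens per block. It models neither system's data structure, only the rounding, and both helper names are illustrative.

```python
def common_prefix_len(a, b):
    """Length of the longest shared token prefix of two sequences."""
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n


def block_aligned_reuse(a, b, block_size=16):
    """Shared prefix rounded down to whole blocks."""
    return (common_prefix_len(a, b) // block_size) * block_size


cached_prompt = list(range(40)) + [900] * 8  # 40 shared tokens, then a tail
new_prompt = list(range(40)) + [901] * 8     # same 40 tokens, different tail

token_level = common_prefix_len(cached_prompt, new_prompt)
block_level = block_aligned_reuse(cached_prompt, new_prompt)

print(token_level)  # 40
print(block_level)  # 32: the 8 shared tokens in the partial block are lost
```

The loss is bounded by one block per request, which is the quantitative form of vLLM’s bet that block granularity is cheap enough.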

What this does and does not solve

Prefix caching is the rare optimization with almost no downside on the workloads it targets: when prefixes repeat, you pay for them once, TTFT collapses, and the only cost is a little hashing and bookkeeping layered on the block management you were already doing. The honest caveats are about when it doesn’t help. Workloads with no shared structure (every prompt unique, no system prompt, single-turn) see hit rates near zero and pay a small hashing tax for nothing. Capacity is finite, so a large enough working set of distinct prefixes thrashes the cache into uselessness, which is the motivation for spilling cold blocks to CPU and disk in Chapter 16. The block-aligned granularity wastes the tail of every partial block, including that forced last-block recompute for logits. And multi-tenant sharing is a real side channel that the cache salt only mitigates if you remember to use it.

The deeper open problem is that prefix reuse is a property of the traffic, and the engine is at the mercy of whoever shapes prompts upstream. Put the variable part first and you have defeated the entire mechanism. Much of the value of prefix caching in practice is captured not in the engine at all but in the conventions of the API above it and the router in front of it, which is why this chapter ends pointing forward. Next we leave decoder-only text behind: Chapter 8 takes a multimodal request, runs an encoder before the LLM, and discovers that the encoder’s outputs want their own prefix cache, keyed by the very media hash we just saw threaded through extra_keys.

Further reading

  • SGLang: Efficient Execution of Structured LM Programs — arXiv:2312.07104 — RadixAttention as an explicit radix-tree alternative to vLLM’s flat hash-of-blocks, with tree-native prefix sharing and leaf eviction.
  • Efficient Memory Management for Large Language Model Serving with PagedAttention — arXiv:2309.06180 — the paged-block foundation (Chapter 4) that makes content-addressed block reuse possible in the first place.