Preface

This is a roadmap of the engineering problems you must solve to serve large language models at scale, and of the techniques the field has converged on to solve them. It is organized as a single argument that builds from the bottom up: it begins with the shape of the problem and the hardware asymmetry that drives everything else, works through every mechanism inside one well-run replica, then layers on the algorithms that lower the cost of a token, the strategies for spreading one model across many GPUs, and finally the concerns of operating a whole fleet. Each chapter assumes the ones before it.

It is written for engineers who already know how to build and operate distributed systems and now need to reason precisely about inference serving, whether you are evaluating a serving stack, tuning one in production, or trying to understand why a latency SLO is being missed. It is not an introduction to transformers or to deep learning, and it is not a how-to for any single deployment. It is a way of thinking about where the costs are and which lever moves which number.

Throughout, vLLM is treated as the canonical implementation. Where a technique is described, it is grounded in vLLM’s actual source so that the abstraction stays honest and you can go read the code yourself. vLLM is distributed under the Apache License 2.0; source quoted here is reproduced under that license, with attribution to the vLLM project and its contributors. Quotations are kept short and serve to anchor the discussion in the real control flow, not to substitute for reading the repository.

Research papers are cited as jumping-off points, not reproduced. When a chapter rests on an idea from the literature, it names the paper and explains the idea in its own words, then points you to the original for the proofs, the measurements, and the nuance. Each chapter ends with a short “Further reading” list for exactly this purpose, and those lists are aggregated in the References at the back.

This is a living roadmap of a field that is still moving quickly. The mechanisms described here are the ones that matter as of this writing, but the frontier shifts with every model release and every kernel. Treat the structure as durable and the specifics as a snapshot: the questions each chapter asks will outlast the particular answers vLLM gives today.

What “serving” means for an autoregressive model

If you have spent your career moving bytes, you carry a model of “serving” in your bones. A request arrives, a handler does a bounded amount of work, a response goes out, and the connection is recycled. The unit of accounting is the call. Capacity is requests per second. The hard problems are the ones you already know how to name: head-of-line blocking, tail latency, fan-out amplification, backpressure when a downstream slows. Everything in that worldview rests on one quiet assumption, so deeply baked in that it rarely gets stated: a request’s cost is roughly knowable up front, and the time it occupies a worker is short and bounded.

An autoregressive language model breaks that assumption at the root. Not at the edges, not in some pathological tail, but for the median request on a quiet day. This chapter is about why, and about the shape the rest of the stack is forced into as a consequence. The thesis is simple to state and expensive to absorb: an LLM request is not a function call, it is a long-lived state machine, and the cost it incurs is variable, two-phase, and stateful. Once you take that seriously, almost every later technique in this book stops looking like a clever trick and starts looking like the only thing you could have done.

A token at a time, forever

Start from the computation itself. A transformer generating text does not produce its answer in one shot. It produces one token, appends that token to its own input, and runs the whole forward pass again to produce the next one. (A token is the model’s atomic unit of text, roughly a word-piece; the model’s vocabulary is a fixed list of them, and a forward pass is one full sweep of the input through all the model’s layers, ending in a probability distribution over that vocabulary from which the next token is drawn.) The output is fed back as input, over and over, until the model emits a special end-of-sequence token or hits a configured length cap. That feedback loop is what “autoregressive” means: each output is conditioned on, and literally becomes part of, the input for the next prediction.

The diagram below traces that loop for a three-token answer. Notice that there is no single “run the model” box; there is a box per output token, and each one feeds the next.

flowchart LR
    P["prompt: 'The capital of France is'"] --> F1["forward pass 1"]
    F1 --> T1["emit ' Paris'"]
    T1 --> F2["forward pass 2 (prompt + ' Paris')"]
    F2 --> T2["emit '.'"]
    T2 --> F3["forward pass 3 (prompt + ' Paris.')"]
    F3 --> STOP["emit end-of-sequence: done"]

This loop has two consequences that a stateless-RPC mental model has no slot for.

The first is that the length of the work is not known when the request arrives. A chat completion might generate eight tokens or eight hundred. The client did not tell you (at most it set an upper bound), the model does not know yet, and the only way to find out is to run it and watch for the stop condition. Your scheduler is admitting a job whose duration is a random variable it cannot observe until the job finishes. Every queueing intuition you have that depends on knowing or estimating service time needs an asterisk.

The second is that the work is stateful across steps. To produce the next token, the model’s attention mechanism lets the current position look back at every previous position and weigh how relevant each is. Concretely, each earlier token contributed a key vector (what it offers) and a value vector (what it carries), and the new token’s query vector is compared against all of those keys to decide how much of each value to mix in. The crucial point for serving is that the keys and values of the earlier tokens do not change once computed. So recomputing them on every step would mean re-deriving the whole sequence’s keys and values at step 1, then again at step 2, and so on, work that grows with the square of the length, $O(n^2)$ in the sequence length $n$. The standard fix is the KV cache: the per-token key and value vectors computed in earlier steps are kept in GPU memory and reused, so each new step only computes the new token’s own key, value, and query and reads the rest straight out of the cache. The two curves below show what that buys: the total key/value work to generate a sequence is quadratic in its length when you recompute every step, and linear once the cache is in place. We dissect that cache in Chapter 4.

Illustrative: the shapes are exact ($n(n{+}1)/2$ versus $n$ in units of per-position key/value computations), but the absolute counts ignore per-token constant factors and weight reads.

What matters here is the structural fact the cache forces: a request in flight owns GPU memory that grows with every token it generates, and that memory is the live state of the computation, recoverable only by recomputing it. Drop it and the request has to start over. This is the opposite of a stateless handler. It is closer to a long-lived TCP connection with a per-connection buffer that the kernel cannot page out for free.
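The two curves can be reproduced by counting. The sketch below is a toy model, not vLLM code: each step adds one position, and one unit of work is projecting one position’s key and value. Attention reads and weight traffic are ignored, as in the illustrative curves. The function name kv_projection_work is hypothetical.

```python
def kv_projection_work(n_tokens: int, use_cache: bool) -> int:
    """Count per-position key/value projections needed to reach n_tokens positions.

    Toy model: one new position per step; each unit of work is projecting the
    key and value of one position. Attention reads and weight traffic are ignored.
    """
    cached = 0  # positions whose K/V are already stored in the cache
    work = 0
    for step in range(1, n_tokens + 1):
        if use_cache:
            work += step - cached  # project only the positions not yet cached
            cached = step          # the cache now holds every position so far
        else:
            work += step           # re-project every position seen so far
    return work


if __name__ == "__main__":
    for n in (10, 100, 1000):
        print(n, kv_projection_work(n, use_cache=False), kv_projection_work(n, use_cache=True))
```

Running it prints 55 versus 10, 5050 versus 100, and 500500 versus 1000: the uncached count is exactly $n(n+1)/2$ and the cached count is $n$. The price of the linear curve is that all $n$ cached positions must stay resident for the life of the request, which is exactly the per-request memory described above.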

Two phases, two completely different cost profiles

There is a further wrinkle that has no analogue at all in conventional serving: a single request runs in two phases that stress the hardware in opposite ways.

When the request first arrives, the model has to ingest the whole prompt. Every prompt token is processed, and the KV cache for the entire prompt is populated in essentially one big parallel pass. This is prefill. Because all the prompt tokens are available at once, the engine can stack them into one large matrix multiply and keep the GPU’s arithmetic units saturated; the bottleneck is how fast the chip can do math, so prefill is compute-bound (the arithmetic units are the limiting resource). Then the request flips into decode, generating output tokens one at a time. Here is the asymmetry. A decode step processes exactly one new token, so it does only one token’s worth of arithmetic, but to do even that it must stream the entire model’s weights and the request’s entire growing KV cache out of GPU memory and through the compute units. One token’s worth of math, the whole model’s worth of memory traffic. The bottleneck is no longer how fast the chip computes but how fast it can move bytes, so decode is memory-bandwidth-bound, and because a long answer is hundreds of such steps, decode is where almost all of a long request’s wall-clock time goes.
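A back-of-envelope calculation makes the asymmetry concrete. This sketch assumes the common rule of thumb that a forward pass costs about 2 FLOPs (one multiply, one add) per parameter per token, and that every weight is read from memory once per pass no matter how many tokens share it. It ignores attention arithmetic and KV-cache traffic, and the 8-billion-parameter, 16-bit model is hypothetical. Chapter 3 replaces this with a proper roofline.

```python
def weight_arithmetic_intensity(tokens_per_pass: int, bytes_per_param: float = 2.0) -> float:
    """FLOPs performed per byte of weights read in one forward pass.

    Assumes ~2 FLOPs per parameter per token and one read of every weight per
    pass; attention FLOPs and KV-cache reads are deliberately ignored.
    """
    flops_per_param = 2.0 * tokens_per_pass   # all tokens in the pass reuse the weight
    return flops_per_param / bytes_per_param  # the parameter count cancels out


if __name__ == "__main__":
    params = 8e9               # hypothetical 8B-parameter model
    weight_bytes = params * 2  # 16-bit weights
    for label, tokens in (("prefill, 2048-token prompt", 2048), ("decode, one token", 1)):
        print(f"{label}: {weight_arithmetic_intensity(tokens):.0f} FLOPs per weight byte")
    print(f"weights read per decode step: {weight_bytes / 1e9:.0f} GB")
```

Under these assumptions a 2048-token prefill does about two thousand FLOPs for every byte of weights fetched, while a single decode token does one, far too little to keep the arithmetic units busy, yet it still hauls all 16 GB of weights through the memory system.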

The diagram below contrasts the two phases at a glance. The key takeaway is that the thing that makes prefill efficient (many tokens in flight at once, so the expensive-to-fetch weights are amortized over a lot of arithmetic) is exactly the thing decode structurally lacks: one token per step means the weights get re-read for almost no math.

flowchart TD
    subgraph PREFILL["prefill: once, at request start"]
        A["all N prompt tokens at once"] --> B["one big parallel matrix multiply"]
        B --> C["KV cache for the whole prompt"]
        B --> D["bottleneck: arithmetic units (compute-bound)"]
    end
    subgraph DECODE["decode: repeated, one step per output token"]
        E["one new token"] --> F["read ALL weights + growing KV cache"]
        F --> G["emit one token, append to KV cache"]
        G --> E
        F --> H["bottleneck: memory bandwidth (memory-bound)"]
    end
    PREFILL --> DECODE

Chapter 3 turns this asymmetry into a roofline argument; for now, just hold the picture: one request, two regimes, and what each regime is starved of is the opposite of the other.

This is why “fast” is a meaningless adjective for an inference system, and why Chapter 2 has to split latency into at least two SLOs that trade against each other. Time to first token is dominated by prefill and queueing. The per-token latency of the stream that follows is a decode property. They are different metrics measuring different phases of the same request, and you can be excellent at one while being terrible at the other.

The request is a state machine, and vLLM says so out loud

Because the work is long-lived, variable, and stateful, the engine cannot treat a request as a stack frame that exists for the duration of one call. It has to model the request as an explicit object with an explicit lifecycle, persisted across hundreds of scheduler ticks. vLLM makes this concrete. Every request the engine knows about carries a status drawn from a small enum, in vllm/v1/request.py:

class RequestStatus(enum.IntEnum):
    """Status of a request."""

    WAITING = enum.auto()
    WAITING_FOR_STRUCTURED_OUTPUT_GRAMMAR = enum.auto()
    WAITING_FOR_REMOTE_KVS = enum.auto()
    WAITING_FOR_STREAMING_REQ = enum.auto()
    RUNNING = enum.auto()
    PREEMPTED = enum.auto()
    # Note: anything after PREEMPTED will be considered
    # as a finished status.
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()
    FINISHED_IGNORED = enum.auto()
    FINISHED_ERROR = enum.auto()
    FINISHED_REPETITION = enum.auto()

Read this enum as the table of contents for Part II. WAITING is the admission queue: the request is known to the engine but holds no GPU memory yet. RUNNING is in the live batch, actively advancing a few tokens per step. The various WAITING_FOR_* states are stalls on resources that do not exist in a stateless world: a request can be parked because its grammar has not finished compiling (Chapter 14), or because its KV cache is being fetched from a remote node (Chapters 16 and 17). And PREEMPTED is the one that should make a traffic engineer sit up: it is the state of a request that was running and got evicted to free GPU memory for something else, then sent back to the waiting queue to be re-admitted later.

The lifecycle these states describe is genuinely a state machine, and the diagram below traces the paths a request can take through it. Two edges are the ones with no stateless analogue: the loop where RUNNING advances itself step after step without leaving the state, and the demotion from RUNNING back to WAITING through PREEMPTED.

stateDiagram-v2
    [*] --> WAITING: request arrives
    WAITING --> WAITING_FOR_RESOURCE: needs grammar or remote KV
    WAITING_FOR_RESOURCE --> WAITING: resource ready
    WAITING --> RUNNING: admitted, KV memory granted
    RUNNING --> RUNNING: step advances a few tokens
    RUNNING --> PREEMPTED: evicted to free GPU memory
    PREEMPTED --> WAITING: re-queued, KV thrown away
    RUNNING --> FINISHED: stop token / length cap / abort / error
    FINISHED --> [*]

The note in the source is precise about a subtlety that the integer ordering encodes: anything numerically past PREEMPTED counts as finished, so the engine can test status > RequestStatus.PREEMPTED rather than comparing against a set of terminal values. There is not one terminal state but six, and the distinction is load-bearing because each terminal status maps to the finish reason the client is told (though not one-to-one: FINISHED_LENGTH_CAPPED and FINISHED_IGNORED both surface as LENGTH):

_FINISHED_REASON_MAP = {
    RequestStatus.FINISHED_STOPPED: FinishReason.STOP,
    RequestStatus.FINISHED_LENGTH_CAPPED: FinishReason.LENGTH,
    RequestStatus.FINISHED_ABORTED: FinishReason.ABORT,
    RequestStatus.FINISHED_IGNORED: FinishReason.LENGTH,
    RequestStatus.FINISHED_ERROR: FinishReason.ERROR,
    ...
}

Source: vllm/v1/request.py
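To see what the ordering buys, here is a standalone re-creation of the quoted enum with a one-comparison finished test and a status-to-reason table. is_finished and FINISH_REASON are names chosen for this sketch, and the lowercase reason strings stand in for vLLM’s FinishReason values. Because the quoted map elides its remaining entries, FINISHED_REPETITION is deliberately left out of the table rather than guessed.

```python
import enum


class RequestStatus(enum.IntEnum):
    """Re-creation of the quoted enum; auto() numbers members 1, 2, 3, ... in order."""

    WAITING = enum.auto()
    WAITING_FOR_STRUCTURED_OUTPUT_GRAMMAR = enum.auto()
    WAITING_FOR_REMOTE_KVS = enum.auto()
    WAITING_FOR_STREAMING_REQ = enum.auto()
    RUNNING = enum.auto()
    PREEMPTED = enum.auto()
    # Everything declared after PREEMPTED is terminal.
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()
    FINISHED_IGNORED = enum.auto()
    FINISHED_ERROR = enum.auto()
    FINISHED_REPETITION = enum.auto()


def is_finished(status: RequestStatus) -> bool:
    # Hypothetical helper: one integer comparison instead of a membership
    # test against a set of six terminal statuses.
    return status > RequestStatus.PREEMPTED


# Hypothetical status -> client-facing reason table. FINISHED_REPETITION is
# omitted because the quoted vLLM map elides its remaining entries.
FINISH_REASON = {
    RequestStatus.FINISHED_STOPPED: "stop",
    RequestStatus.FINISHED_LENGTH_CAPPED: "length",
    RequestStatus.FINISHED_ABORTED: "abort",
    RequestStatus.FINISHED_IGNORED: "length",  # shares LENGTH with the length cap
    RequestStatus.FINISHED_ERROR: "error",
}


if __name__ == "__main__":
    print([s.name for s in RequestStatus if is_finished(s)])
```

The design choice is that adding a new terminal status costs nothing at the call sites: declare it after PREEMPTED and every status > PREEMPTED test picks it up automatically.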

The Request object that carries this status is not a thin wrapper around the input. It accumulates the state of the computation as it runs. The fields that matter most are the bookkeeping the scheduler will lean on every tick:

        self.spec_token_ids: list[int] = []
        self.num_computed_tokens = 0
        self.cache_salt: str | None = cache_salt

Source: vllm/v1/request.py

num_computed_tokens is the heart of it. It is how far through its own token sequence the request has been processed so far, and it is what makes the two-phase distinction disappear from the scheduler’s point of view. A request mid-prefill and a request mid-decode are both just “a sequence whose num_computed_tokens is behind its total length,” and the scheduler’s only job is to help each one catch up under a budget. Chapter 5 builds the whole continuous-batching scheduler on exactly this framing. The reason it can is visible right here in the request’s own state: the engine never stopped tracking how much of each request it had already done.

Note also that num_computed_tokens is a plain mutable integer, and preemption simply resets it to zero. That is the entire recovery mechanism. There is no checkpoint, no partial-result salvage. When the engine reclaims a preempted request’s memory it throws away the computed KV, and the request re-prefills from scratch on readmission. Preemption-as-load-shedding (Chapter 6) is brutal precisely because the state it sheds is expensive and can only be rebuilt by recomputing it. In a stateless system, shedding load is cheap, because there is no state. Here it costs you everything that request had computed.
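The num_computed_tokens framing fits in a few lines. ToyRequest below is a hypothetical stand-in, not vLLM’s Request class: it tracks only the token lists and the computed-token counter, to show that prefill, decode, and recovery from preemption all reduce to the same question of how far behind the request is.

```python
from dataclasses import dataclass, field


@dataclass
class ToyRequest:
    """Hypothetical stand-in for a request's scheduling state (not vLLM's Request)."""

    prompt_token_ids: list[int]
    output_token_ids: list[int] = field(default_factory=list)
    num_computed_tokens: int = 0  # positions whose KV is already in the cache

    @property
    def num_tokens(self) -> int:
        return len(self.prompt_token_ids) + len(self.output_token_ids)

    @property
    def num_tokens_behind(self) -> int:
        # Prefill and decode look the same: how far computation lags the sequence.
        return self.num_tokens - self.num_computed_tokens

    def preempt(self) -> None:
        # The whole recovery mechanism: forget all computed KV and re-prefill later.
        self.num_computed_tokens = 0


if __name__ == "__main__":
    req = ToyRequest(prompt_token_ids=list(range(1000)))
    req.num_computed_tokens = 512    # first prefill chunk done
    print(req.num_tokens_behind)     # 488: still prefilling
    req.num_computed_tokens = 1000   # prefill finished
    req.output_token_ids.append(42)  # one sampled token, KV not yet computed
    print(req.num_tokens_behind)     # 1: decode is "one behind"
    req.preempt()
    print(req.num_tokens_behind)     # 1001: everything must be recomputed
```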

The wire format admits it too

The state machine is not just an internal convenience; it leaks into the contracts between processes, because vLLM is structured as a frontend and an engine core in separate OS processes (the split is documented in docs/design/arch_overview.md and docs/design/multiprocessing.md, and Chapter 11 returns to it). The frontend ships an EngineCoreRequest in, and the engine streams EngineCoreOutputs back. Look at what an output actually is, in vllm/v1/engine/__init__.py:

class EngineCoreOutput(
    msgspec.Struct,
    ...
):
    request_id: str
    new_token_ids: list[int]

    new_logprobs: LogprobsLists | None = None
    ...
    finish_reason: FinishReason | None = None
    stop_reason: int | str | None = None
    events: list[EngineCoreEvent] | None = None

This is not a response. It is an incremental update keyed by request_id, carrying the new tokens since last time and a finish_reason that is None until the request actually ends:

    @property
    def finished(self) -> bool:
        return self.finish_reason is not None

Source: vllm/v1/engine/__init__.py
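On the receiving side, an incremental protocol means the consumer, not the engine, assembles each response. The sketch below is hypothetical: ToyEngineCoreOutput keeps only three fields of the quoted struct and uses a plain string for the finish reason, and StreamAccumulator illustrates the pattern rather than reproducing vLLM’s output processing.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyEngineCoreOutput:
    """Hypothetical, trimmed stand-in for EngineCoreOutput."""

    request_id: str
    new_token_ids: list[int]
    finish_reason: Optional[str] = None

    @property
    def finished(self) -> bool:
        return self.finish_reason is not None


class StreamAccumulator:
    """Rebuilds each request's output from interleaved per-step updates."""

    def __init__(self) -> None:
        self.tokens: dict[str, list[int]] = {}
        self.done: dict[str, str] = {}

    def apply(self, out: ToyEngineCoreOutput) -> None:
        # Updates are keyed by request_id and carry only the tokens since last time.
        self.tokens.setdefault(out.request_id, []).extend(out.new_token_ids)
        if out.finished:
            self.done[out.request_id] = out.finish_reason


if __name__ == "__main__":
    steps = [
        [ToyEngineCoreOutput("a", [11]), ToyEngineCoreOutput("b", [21])],
        [ToyEngineCoreOutput("a", [12], "stop"), ToyEngineCoreOutput("b", [22])],
        [ToyEngineCoreOutput("b", [23], "length")],
    ]
    acc = StreamAccumulator()
    for step_outputs in steps:
        for out in step_outputs:
            acc.apply(out)
    print(acc.tokens)  # {'a': [11, 12], 'b': [21, 22, 23]}
    print(acc.done)    # {'a': 'stop', 'b': 'length'}
```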

A single logical request produces a stream of these, one per engine step that advanced it, over its entire multi-hundred-step life. The engine even timestamps the request’s transitions explicitly, because the lifecycle is something operators need to observe rather than infer:

class EngineCoreEventType(enum.IntEnum):
    """The type of engine core request event."""

    QUEUED = 1
    SCHEDULED = 2
    PREEMPTED = 3

Source: vllm/v1/engine/__init__.py

QUEUED, SCHEDULED, PREEMPTED are the edges of the state machine made into telemetry. The intervals between them are what become the TTFT and queue-time metrics of Chapter 2. The protocol between the engine’s two halves is, in other words, a state-machine-replication protocol, not a request/response one.
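Turning those edges into durations is a matter of subtracting timestamps. The sketch below mirrors the three quoted event types; Event, queue_time, and preemption_count are hypothetical, and measuring queue time up to the first SCHEDULED event is one reasonable policy, not a statement of how vLLM’s metrics code defines it.

```python
import enum
from dataclasses import dataclass


class EventType(enum.IntEnum):
    # Mirrors the three members of the quoted EngineCoreEventType.
    QUEUED = 1
    SCHEDULED = 2
    PREEMPTED = 3


@dataclass
class Event:
    type: EventType
    timestamp: float  # seconds, all taken from one clock


def queue_time(events: list[Event]) -> float:
    """Initial wait: first QUEUED to first SCHEDULED (assumes both exist).

    A preempted request goes back to the queue, so later SCHEDULED events may
    follow; this helper deliberately counts only the first admission.
    """
    queued = next(e.timestamp for e in events if e.type == EventType.QUEUED)
    scheduled = next(e.timestamp for e in events if e.type == EventType.SCHEDULED)
    return scheduled - queued


def preemption_count(events: list[Event]) -> int:
    return sum(e.type == EventType.PREEMPTED for e in events)


if __name__ == "__main__":
    events = [
        Event(EventType.QUEUED, 10.0),
        Event(EventType.SCHEDULED, 10.25),
        Event(EventType.PREEMPTED, 12.0),
        Event(EventType.SCHEDULED, 13.5),
    ]
    print(queue_time(events), preemption_count(events))  # 0.25 1
```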

One step, many requests in flight

Tie it together at the engine’s core loop. A conventional server processes a request to completion before reclaiming the worker. vLLM does the opposite: it re-plans the entire in-flight set on every single forward pass. The loop in vllm/v1/engine/core.py is almost anticlimactic:

    def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
        ...
        if not self.scheduler.has_requests():
            return {}, False
        scheduler_output = self.scheduler.schedule()
        future = self.model_executor.execute_model(scheduler_output, non_block=True)
        grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
        ...
            model_output = future.result()
        ...
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output
        )

Walk the three calls in order. schedule() decides which requests, and how many of each one’s tokens, get GPU time this tick, packing them into a single batch under a fixed token budget. execute_model runs one forward pass for that whole assembled batch, so a prompt still being prefilled and a request a hundred tokens into its answer ride the same matrix multiply. update_from_output then appends the freshly sampled token to each request, advances its num_computed_tokens, checks each one’s stop conditions, flips finished requests into the right FINISHED_* state, and emits the incremental outputs. Then the loop does it again.

The diagram below traces one such step and shows why a request is never “owned” by the engine for the duration of its answer. The scheduler reassembles the batch every tick from whoever is currently runnable, so across a request’s life it drifts in and out of batches it shares with strangers.

sequenceDiagram
    participant S as "scheduler"
    participant E as "model executor (GPU)"
    participant B as "in-flight requests"
    loop every step
        S->>B: pick runnable requests under token budget
        S->>E: execute_model(batch)
        E->>E: one forward pass for the whole batch
        E-->>S: one new token per request in batch
        S->>B: append token, advance num_computed_tokens
        S->>B: check stop conditions, mark FINISHED_*
        S-->>S: emit incremental EngineCoreOutputs
    end

A request lives across hundreds or thousands of these iterations, sharing each one with dozens of unrelated requests at wildly different points in their own lifecycles. There is no moment where the engine is “handling your request.” There is only a steady cadence of steps, and your request is a row in the batch that some steps include and others do not.

That is the genuinely new thing. The unit of work is not the request and it is not even the token. It is the step: a fixed-budget tick that advances many long-lived state machines by a little. Capacity is tokens per step times steps per second, divided across whoever is in flight. Tail latency is not “one slow handler” but “your request kept losing the budget fight, or got preempted, or stalled waiting for a remote KV block.” The familiar problems are all still here, but they have been pushed down a level, from between requests to within the batch.
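The step contract can be sketched as a toy simulation under heavy simplifications: first-come-first-served order, no KV-memory limit, no preemption, and a placeholder id for every sampled token. Req and step are hypothetical names; vLLM’s real scheduler is the subject of Chapter 5.

```python
from dataclasses import dataclass, field


@dataclass
class Req:
    """Hypothetical toy request: prompt length, output target, progress."""

    rid: str
    prompt_len: int
    max_new_tokens: int
    num_computed: int = 0
    output: list[int] = field(default_factory=list)

    @property
    def total_len(self) -> int:
        return self.prompt_len + len(self.output)

    @property
    def finished(self) -> bool:
        return len(self.output) >= self.max_new_tokens


def step(running: list[Req], token_budget: int) -> dict[str, int]:
    """One engine tick: give each unfinished request up to its backlog, within budget."""
    scheduled: dict[str, int] = {}
    budget = token_budget
    for r in running:
        if r.finished or budget == 0:
            continue
        n = min(r.total_len - r.num_computed, budget)  # a prefill chunk, or 1 decode token
        scheduled[r.rid] = n
        budget -= n
    # Stand-in for the forward pass: advance progress; a request that has
    # caught up with its whole sequence samples one new token.
    for r in running:
        n = scheduled.get(r.rid, 0)
        r.num_computed += n
        if n and r.num_computed == r.total_len:
            r.output.append(0)  # placeholder token id
    return scheduled


if __name__ == "__main__":
    running = [Req("long", prompt_len=10, max_new_tokens=2), Req("short", prompt_len=2, max_new_tokens=3)]
    while not all(r.finished for r in running):
        print(step(running, token_budget=6))
```

With a budget of 6 tokens per step, the trace is {'long': 6}, then {'long': 4, 'short': 2}, then {'long': 1, 'short': 1}, then {'short': 1}. The long prompt’s prefill is split across two steps, in the second step it shares the batch with the short request’s whole prefill, and from then on both advance one token per step until each finishes.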

What this forces, and where we go next

If you accept that a request is a long-lived, variable-length, two-phase, memory-owning state machine, the rest of the book stops being a grab bag and becomes a consequence. The KV cache is state, so memory, not compute, is what caps concurrency, and you have to manage it like a memory allocator (Chapter 4). The work is variable-length and re-planned per step, so the scheduler is a token-level rate limiter, not a thread pool (Chapter 5), and load-shedding means evicting live state (Chapter 6). Identical prefixes are recomputed state you already paid for, so you cache and share them (Chapter 7). The two phases stress the hardware so differently that you eventually run them on separate machines entirely (Chapter 17). Every one of these is an answer to a question that a stateless-RPC stack never had to ask.

One honest caveat before we proceed. The picture of a monotonically growing per-request KV cache is the common case, not a law. Some attention variants compress that state into a small fixed-size latent (MLA, Chapter 9), and some architectures replace the growing cache with a constant-size recurrent state, which changes the memory math in Chapter 4. The state machine stays; what varies is how much memory each state costs. We will be careful about that distinction when it matters.

The problem statement is now on the table. The next chapter makes it measurable: which latencies, at which percentiles, under which load, and why the cost denominator that decides whether your fleet is affordable is GPU-seconds rather than requests. Until you can name the metric, you cannot tell whether any of the machinery in Part II is actually helping.

Further reading

  • Efficient Memory Management for Large Language Model Serving with PagedAttention — arXiv:2309.06180 — Frames the central claim that KV-cache memory, not arithmetic, is the binding constraint on LLM serving concurrency; read it now for the problem statement, and again with Chapter 4 for the mechanism.

The metrics that define the SLO: TTFT, ITL/TPOT, throughput, goodput, cost

Chapter 1 left us with a request that is not a single RPC but a long-lived state machine: it queues, it prefills, it decodes one token at a time, and it streams those tokens back over a connection that stays open for seconds. That shape has a consequence we now have to confront head-on. A stateless service has essentially one latency number that matters, request latency, and you reason about its distribution. An autoregressive request has at least two, and they are not the same kind of thing. One measures how long you wait before anything happens. The other measures how smoothly things keep happening once they start. They are produced by different phases of the computation, they are bounded by different hardware resources (Chapter 3 will name those resources precisely), and worst of all, they trade against each other and against throughput. You cannot maximize all three. So before we can optimize anything, we have to agree on what “fast” even means, and the honest answer is that the word is meaningless until you have named a metric, a percentile, and a load.

This chapter is about getting that vocabulary exactly right, because every later chapter is, at bottom, a story about moving one of these numbers at the expense of another. The good news for a traffic-infrastructure engineer is that none of the queueing intuition you already have is wasted. The bad news is that you have to apply it to two coupled queues at once, and the denominator you divide cost by is not requests, it is GPU-seconds.

Two latency SLOs, not one

Start with the two latencies, because they are the ones people conflate. The first is time to first token, TTFT: the interval from when the request arrives at the engine to when the user sees the first generated token. It is dominated by two things: the time the request spends waiting in the queue before it is scheduled (call this the queue delay), and the cost of the prefill forward pass over the whole prompt (the one big computation that reads every prompt token and produces the first output token). The second is inter-token latency, ITL: the gap between consecutive streamed tokens during the decode phase (the long run of one-token-at-a-time steps that follows prefill). Its per-request average is often reported separately as time per output token, TPOT. The two latencies live in two different stretches of a request’s life, and the diagram below traces that life from arrival to the last streamed token, marking exactly which interval each metric measures.

sequenceDiagram
    participant U as User / client
    participant Q as Engine queue
    participant E as GPU engine
    U->>Q: request arrives (start time st)
    Note over Q: queue delay (waiting to be scheduled)
    Q->>E: scheduled
    Note over E: prefill, one pass over the whole prompt
    E-->>U: token 1
    Note over U,E: TTFT = queue delay + prefill (measured once)
    Note over E: decode step
    E-->>U: token 2
    Note over U,E: ITL sample = gap before this token
    Note over E: decode step
    E-->>U: token 3
    Note over U,E: ITL sample
    Note over E: decode continues
    E-->>U: token N (last)
    Note over U,E: TPOT = (latency - TTFT) / (N - 1)

Read the diagram top to bottom as the timeline of a single request. Everything above the first arrow out of the engine is folded into TTFT; everything after it is the territory of ITL and TPOT. That single visual split is the whole reason one latency number is not enough.

vLLM measures both, and the cleanest way to see the definitions is in the offline benchmark client, which times them from the outside exactly as a user’s browser would. In vllm/benchmarks/lib/endpoint_request_func.py, each streamed chunk is stamped with timestamp = time.perf_counter(), and the first chunk is then treated specially:

# First token
if ttft == 0.0:
    ttft = timestamp - st
    output.ttft = ttft

# Decoding phase
else:
    output.itl.append(timestamp - most_recent_timestamp)

most_recent_timestamp = timestamp

That is the whole distinction. The code keeps a running most_recent_timestamp; the first chunk to arrive (when ttft is still its sentinel 0.0) is timed against the request’s start time st and recorded as the one TTFT, and every chunk after that records the gap since the previous chunk as one ITL sample. TTFT is therefore measured exactly once per request; ITL is measured once per chunk after the first. When each streamed chunk carries one token, a request that emits $N$ output tokens produces exactly one TTFT and $N - 1$ ITLs, so a 500-token request yields one TTFT and 499 ITLs. This asymmetry in sample counts matters later when we talk about percentiles: a single long answer floods the ITL pool with hundreds of samples but contributes only one TTFT sample, so the two distributions are not weighted the same way even when they come from the same traffic.
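The same bookkeeping can be written as a standalone function, assuming each streamed chunk carries exactly one token. ttft_and_itls is a hypothetical helper, not part of vLLM; it takes the first timestamp by position instead of testing a 0.0 sentinel, which produces the same split.

```python
def ttft_and_itls(start: float, chunk_times: list[float]) -> tuple[float, list[float]]:
    """Split streamed-chunk arrival times into one TTFT and N - 1 ITL samples."""
    if not chunk_times:
        raise ValueError("no tokens were streamed")
    ttft = chunk_times[0] - start  # first chunk: measured against the request start
    # Every later chunk: the gap since the previous chunk.
    itls = [later - earlier for earlier, later in zip(chunk_times, chunk_times[1:])]
    return ttft, itls


if __name__ == "__main__":
    times = [0.5 + 0.02 * i for i in range(500)]  # 500 tokens, 20 ms apart
    ttft, itls = ttft_and_itls(0.0, times)
    print(ttft, len(itls))  # 0.5 499
```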

TPOT is then just the decode duration amortized over the output tokens, computed per request in vllm/benchmarks/serve.py:

tpot = 0
if output_len > 1:
    latency_minus_ttft = outputs[i].latency - outputs[i].ttft
    tpot = latency_minus_ttft / (output_len - 1)
    tpots.append(tpot)

Note the output_len - 1: the first token’s cost lives in TTFT, so TPOT divides the remaining latency by the remaining tokens. Writing the end-to-end request latency as $L$ and the output length as $N$, that is

$$\text{TPOT} = \frac{L - \text{TTFT}}{N - 1}.$$

ITL is the raw per-step distribution; TPOT is its per-request mean. They answer different questions. ITL tells you whether the stream stutters; TPOT tells you the steady-state rate a given request saw. A request can have a fine TPOT and still stutter badly if a few ITLs spiked, which is exactly what happens when a long prefill from another request lands in the same step and stalls everyone’s decode. Hold that thought; it is the entire motivation for chunked prefill in Chapter 6.
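Both claims can be checked numerically. The sketch below assumes latency is measured at the last token and each chunk carries one token; under those assumptions TPOT is exactly the mean of the request’s ITLs, which is why it cannot show a stall. The ITL values are invented for illustration, and tpot is a hypothetical helper that applies the formula above.

```python
import statistics


def tpot(latency: float, ttft: float, output_len: int) -> float:
    """(L - TTFT) / (N - 1), returning 0 for single-token outputs as the benchmark does."""
    if output_len <= 1:
        return 0.0
    return (latency - ttft) / (output_len - 1)


if __name__ == "__main__":
    itls = [0.02] * 98 + [1.0]  # 99 gaps: steady 20 ms, plus one 1 s stall
    ttft = 0.3
    latency = ttft + sum(itls)  # latency measured at the last token
    t = tpot(latency, ttft, output_len=len(itls) + 1)
    print(f"TPOT {t * 1000:.1f} ms, mean ITL {statistics.fmean(itls) * 1000:.1f} ms, "
          f"worst ITL {max(itls) * 1000:.0f} ms")
```

The request reports a TPOT of about 30 ms, which looks healthy, while one of its gaps froze the stream for a full second. Only the ITL distribution exposes the stall.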

The engine reports the same quantities internally, not from the client’s vantage but from the engine core’s. In vllm/v1/metrics/stats.py, the per-request state object carries the timestamps, and the first-token interval is computed against the request’s arrival time:

first_token_latency = self._time_since(req_stats.arrival_time)
self.time_to_first_tokens_iter.append(first_token_latency)
req_stats.first_token_latency = first_token_latency

and every later token contributes an ITL measured between engine-core timestamps:

itl = engine_core_timestamp - req_stats.last_token_ts
self.inter_token_latencies_iter.append(itl)

Source: vllm/v1/metrics/stats.py

These feed the Prometheus histograms in vllm/v1/metrics/loggers.py: vllm:time_to_first_token_seconds and vllm:inter_token_latency_seconds, plus a per-request vllm:request_time_per_output_token_seconds. It is worth a glance at the bucket boundaries, because the people who wrote them encoded their expectations about what these distributions look like. TTFT buckets run from a millisecond out to thousands of seconds:

histogram_time_to_first_token = self._histogram_cls(
    name="vllm:time_to_first_token_seconds",
    documentation="Histogram of time to first token in seconds.",
    buckets=[
        0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
        0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0, 160.0,
        640.0, 2560.0,
    ],

The fact that TTFT needs buckets out to 2560 seconds and ITL only to 80 tells you something true: TTFT has a far longer tail, because it absorbs queueing delay. Under load, the prefill compute is roughly constant for a given prompt length, but the wait in the queue grows without bound as the arrival rate approaches what the engine can drain. When people complain that “the model got slow,” they almost always mean TTFT, and they almost always mean its tail.

Why one number is a lie: percentile and load

Here is where your distributed-systems instincts pay off, and where most LLM benchmark numbers fall apart. A mean TTFT is nearly useless, for the same reason a mean request latency is useless in any tail-sensitive system: the mean hides the queue. The engine’s histograms exist precisely so you can ask for p50, p95, p99, and the benchmark client reports percentiles by default rather than just means. In serve.py the reported metric set is ttft,tpot,itl and the percentiles are computed straight from the sample arrays:

percentiles_ttft_ms=[
    (p, np.percentile(ttfts or 0, p) * 1000) for p in selected_percentiles
],
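A small synthetic example shows how the mean hides the queue. The TTFT values are invented: 90 requests served immediately at 0.1 s and 10 stuck behind a queue at 5 s. percentile is a hypothetical helper implementing linear interpolation between order statistics, which is numpy.percentile’s default method, so it should agree with the benchmark’s calculation on the same data.

```python
import math
import statistics


def percentile(samples: list[float], p: float) -> float:
    """Linear-interpolation percentile (numpy.percentile's default rule)."""
    xs = sorted(samples)
    if not xs:
        raise ValueError("empty sample")
    rank = (len(xs) - 1) * p / 100.0  # fractional position in the sorted list
    lo, hi = math.floor(rank), math.ceil(rank)
    return xs[lo] + (xs[hi] - xs[lo]) * (rank - lo)


if __name__ == "__main__":
    ttfts = [0.1] * 90 + [5.0] * 10
    print(f"mean {statistics.fmean(ttfts):.2f} s, "
          f"p50 {percentile(ttfts, 50):.2f} s, p99 {percentile(ttfts, 99):.2f} s")
```

The mean, about 0.59 s, describes no actual request: nine in ten waited 0.1 s and one in ten waited 5 s. The p50 and p99 together tell that story; the mean erases it.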

But a percentile is still only half a statement. The other half is the load it was measured at. TTFT p99 at one request per second and TTFT p99 at the engine’s saturation point are different universes, because the queueing term dominates as you approach capacity, and it does so superlinearly. The intuition is the standard queueing one: call the arrival rate $\lambda$ and the rate the engine can actually drain $\mu$, and let the utilization be $\rho = \lambda / \mu$. As $\rho \to 1$, the average number of requests waiting ahead of you does not grow linearly; it blows up toward infinity, and TTFT (which contains the queue delay) blows up with it. Far below capacity, adding load barely moves p99; near capacity, a tiny increase in load sends p99 vertical. That is the classic latency-versus-throughput knee, and an LLM engine has one just like any queue. The curve below plots the standard M/M/1 mean queue delay, $W_q \propto \rho/(1-\rho)$, against utilization: notice how it stays low and rises only gently out to roughly $\rho \approx 0.8$ and then goes vertical, which is exactly why a TTFT number quoted without its load is meaningless.

Illustrative: the M/M/1 relation $W_q \propto \rho/(1-\rho)$ gives the correct shape of the knee, not a measured TTFT for any specific engine.
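The knee is easy to tabulate. This sketch evaluates the textbook M/M/1 mean waiting time, $W_q = \rho / (\mu(1 - \rho))$, for a fixed drain rate. The value of `MU` is an arbitrary assumption. A real engine's service times are neither exponential nor constant, so only the shape carries over.

```python
def mm1_wait(rho: float, mu: float) -> float:
    """Mean time a request waits in queue in an M/M/1 system (same time units as 1/mu)."""
    if not 0.0 <= rho < 1.0:
        raise ValueError("utilization must be in [0, 1); at rho >= 1 the queue grows without bound")
    return rho / (mu * (1.0 - rho))


MU = 10.0  # hypothetical: the engine drains 10 requests/s, i.e. a 100 ms mean service time
for rho in (0.5, 0.7, 0.8, 0.9, 0.95, 0.99):
    print(f"rho={rho:.2f}  mean queue wait={mm1_wait(rho, MU) * 1000:7.0f} ms")
```

Raising utilization from 50% to 80% quadruples the mean wait (100 ms to 400 ms). Raising it from 90% to 99% multiplies the wait by eleven (900 ms to 9.9 s).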

A benchmark that reports “p99 TTFT = 200ms” without naming the request rate, the prompt-length distribution, and the output-length distribution has told you almost nothing, because it has not told you where on that curve, relative to the knee, the number was taken.

The arrival process matters too, and vLLM’s benchmark is careful about it in a way worth copying. Requests are not fired in a tight loop; they are spaced by a configurable stochastic process. The get_request generator in serve.py draws inter-arrival gaps from a gamma distribution whose shape is the --burstiness knob:

theta = 1.0 / (current_request_rate * burstiness)

# Sample the request interval from the gamma distribution.
# If burstiness is 1, it follows exponential distribution.
delay_ts.append(np.random.gamma(shape=burstiness, scale=theta))

At burstiness = 1 the gamma collapses to an exponential, which is to say arrivals are a Poisson process, the standard well-behaved baseline. Push burstiness below 1 and the gaps become more variable, clustering requests into bursts that hammer the queue; push it above 1 and arrivals become more uniform than Poisson, which is gentler than reality. This single parameter lets you ask the question that actually predicts production behavior: not “what is p99 at 10 rps” but “what is p99 at 10 rps when the traffic is bursty.” The prompt and output length distributions come from the dataset layer in vllm/benchmarks/datasets/datasets.py, which can replay real traces (ShareGPT and friends) or synthesize controlled random lengths, because a benchmark on uniform 128-token prompts will lie to you about a workload of mixed 50-token and 8000-token prompts.
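You can check directly that burstiness reshapes the traffic without changing its average rate. The sketch below reproduces the quoted formula with Python's standard `random.gammavariate(alpha, beta)`, whose second argument is the scale. The `arrival_gaps` helper is hypothetical and not part of vLLM.

```python
import random
import statistics


def arrival_gaps(request_rate: float, burstiness: float, n: int, seed: int = 0) -> list[float]:
    """Inter-arrival gaps (seconds) drawn the way the quoted benchmark code describes."""
    rng = random.Random(seed)
    theta = 1.0 / (request_rate * burstiness)  # scale chosen so the mean gap is 1 / request_rate
    # gammavariate(alpha, beta): alpha is the shape, beta the scale.
    return [rng.gammavariate(burstiness, theta) for _ in range(n)]


for b in (0.25, 1.0, 4.0):
    gaps = arrival_gaps(request_rate=10.0, burstiness=b, n=50_000)
    mean = statistics.fmean(gaps)
    cv = statistics.pstdev(gaps) / mean  # coefficient of variation, 1/sqrt(b) in theory
    print(f"burstiness={b:<4}  mean gap={mean * 1000:6.1f} ms  CV={cv:.2f}")
```

The mean gap should stay near 100 ms in every row. The coefficient of variation should track $1/\sqrt{\text{burstiness}}$: roughly 2 at 0.25, 1 for Poisson, and 0.5 at 4.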

Throughput, and why it fights latency

Throughput is the third axis, and on the surface it is the friendly one: tokens per second, the thing you are paying for. vLLM tracks it as a rate over an interval in vllm/v1/metrics/loggers.py, splitting prompt (prefill) tokens from generation (decode) tokens because they cost differently:

prompt_throughput = self._get_throughput(self.num_prompt_tokens, now)
generation_throughput = self._get_throughput(self.num_generation_tokens, now)
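The rate-over-an-interval idea is simple enough to sketch. The class below is hypothetical: it borrows the `_get_throughput` name from the quoted lines but is not vLLM's logger. It takes explicit timestamps so the example is deterministic.

```python
import time
from typing import Optional, Tuple


class ThroughputTracker:
    """Hypothetical interval-based throughput logger (not vLLM's actual class)."""

    def __init__(self, now: Optional[float] = None) -> None:
        self.last_log_time = time.monotonic() if now is None else now
        self.num_prompt_tokens = 0      # prefill tokens since the last log
        self.num_generation_tokens = 0  # decode tokens since the last log

    def record(self, prompt_tokens: int, generation_tokens: int) -> None:
        self.num_prompt_tokens += prompt_tokens
        self.num_generation_tokens += generation_tokens

    def _get_throughput(self, tracked: int, now: float) -> float:
        # Rate over the interval since the previous log call.
        elapsed = now - self.last_log_time
        return tracked / elapsed if elapsed > 0 else 0.0

    def log(self, now: Optional[float] = None) -> Tuple[float, float]:
        now = time.monotonic() if now is None else now
        prompt_throughput = self._get_throughput(self.num_prompt_tokens, now)
        generation_throughput = self._get_throughput(self.num_generation_tokens, now)
        # Reset so the next log reports a fresh interval, not a running average.
        self.num_prompt_tokens = 0
        self.num_generation_tokens = 0
        self.last_log_time = now
        return prompt_throughput, generation_throughput


tracker = ThroughputTracker(now=0.0)
tracker.record(prompt_tokens=8000, generation_tokens=200)  # a step carrying a big prefill
tracker.record(prompt_tokens=0, generation_tokens=200)     # a pure-decode step
print(tracker.log(now=5.0))  # (1600.0, 80.0)
```

Keeping the two counters separate matters. A single blended number would let one large prefill masquerade as a burst of generation speed.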

The reason throughput deserves its own axis, rather than being a derived consequence of latency, is the mechanism Chapter 5 will build in full: continuous batching. The engine raises throughput by packing more sequences into each forward pass, because a decode step is memory-bandwidth-bound (Chapter 3), so reading the weights once to serve a batch of 200 sequences is dramatically more efficient per token than reading them for one. Larger batches mean higher token throughput and better GPU utilization.

And there is the conflict, stated plainly. The same batching that lifts throughput lengthens latency. A bigger batch takes longer per step, so every sequence’s ITL grows. Admitting a large prefill into the batch to keep the GPU busy stalls the decodes already running, spiking their ITL. It also pushes back the TTFT of the other new requests waiting to be scheduled behind it. The knob that governs this directly is the per-step token budget, max_num_batched_tokens, which Chapter 5 introduces as the scheduler’s rate limiter. Turning it up buys throughput and costs tail latency; turning it down does the reverse. There is no setting that wins both. Because both effects are continuous functions of batch size, the tradeoff is really two curves, not a fork: the chart below sweeps batch size on the x-axis and plots aggregate token throughput against per-sequence ITL on the two y-axes.

Illustrative: a simple model where step time grows linearly with batch size (weights read once, activation work scales with the batch). The shapes (saturating throughput, near-linearly rising ITL) are the real behavior; the absolute numbers are not measured.
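Under the caption's simple model, both curves fall out of one line of arithmetic. The 20 ms weight read and 0.1 ms of per-sequence work below are assumed constants chosen to show the shape, not measurements.

```python
def step_time_s(batch: int, weight_read_s: float = 0.020, per_seq_s: float = 0.0001) -> float:
    """Hypothetical linear step-time model: one fixed weight read plus per-sequence work."""
    return weight_read_s + per_seq_s * batch


for b in (1, 8, 32, 128, 512):
    t = step_time_s(b)
    throughput = b / t  # one decode token per sequence per step, summed over the batch
    print(f"batch={b:4d}  ITL={t * 1000:6.1f} ms  throughput={throughput:8.0f} tok/s")
```

Throughput approaches but never exceeds `1 / per_seq_s`, here 10,000 tokens/s. ITL keeps climbing linearly. That is the saturating-versus-rising pair the chart describes.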

The two curves make the bind concrete: throughput climbs steeply at small batches but flattens as the weight read is fully amortized, while ITL just keeps rising, so every step toward cheaper, higher-throughput operation also pushes tail latency up. No single batch size wins both axes. This is not an implementation wart to be optimized away. It is the central tension of the whole problem, and it was formalized most cleanly by the Sarathi-Serve work (arXiv:2403.02310), which named the throughput-latency tradeoff explicitly and showed that the prefill-decode interference is the thing you have to schedule around. We will quote its mechanism, chunked prefill, in Chapter 6; here it earns its place as the paper that says out loud that you cannot have it all.

Goodput: latency and throughput in one number

If throughput is “tokens per second” and latency is “but were they fast enough,” the obvious move is to combine them, and that combination has a name: goodput. Goodput is the rate of requests that complete while meeting all of their SLOs simultaneously. A request that finishes but blew its TTFT budget counts toward throughput and against you in goodput. This is the metric that actually correlates with a happy user and a defensible cost, and it is why the benchmark supports it directly.

The definition in vllm/benchmarks/serve.py is admirably blunt. You pass SLO thresholds as --goodput ttft:200 tpot:50 and the client checks each completed request against every threshold:

for req_metric in zip(*valid_metrics):
    is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)])
    if is_good_req:
        good_completed += 1

all(...) is the load-bearing word: a request is “good” only if it satisfies the TTFT bound and the TPOT bound and the end-to-end bound, whichever you specified. The diagram below traces one completed request through that gate to show why “good” is a conjunction, not an average: failing any single bound knocks the whole request out of the count, no matter how comfortably it cleared the others.

flowchart TD
    R["completed request"] --> A{"TTFT under budget?"}
    A -->|"no"| BAD["not counted (throughput only)"]
    A -->|"yes"| B{"TPOT under budget?"}
    B -->|"no"| BAD
    B -->|"yes"| C{"end-to-end under budget?"}
    C -->|"no"| BAD
    C -->|"yes"| GOOD["good_completed += 1"]
    GOOD --> RATE["goodput = good_completed / wall_clock_seconds"]

The goodput is then those good completions over wall-clock time, where $C_{\text{good}}$ is the count of SLO-satisfying completions and $T$ is the measurement window:

$$\text{goodput} = \frac{C_{\text{good}}}{T}.$$

request_goodput=good_completed / dur_s,

Source: vllm/benchmarks/serve.py
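The whole gate fits in a few lines. Below is a self-contained sketch of the same conjunction. `RequestMetrics`, its field names, and the `goodput` helper are hypothetical stand-ins rather than vLLM's own types. SLO bounds are in milliseconds, as in the `--goodput ttft:200 tpot:50` example.

```python
from dataclasses import dataclass


@dataclass
class RequestMetrics:
    ttft_ms: float  # time to first token
    tpot_ms: float  # time per output token
    e2el_ms: float  # end-to-end latency


def goodput(requests: list[RequestMetrics], slos_ms: dict[str, float], duration_s: float) -> float:
    """Requests per second that met every SLO that was specified."""
    good_completed = 0
    for req in requests:
        # Conjunction, as in the quoted loop: missing any one bound disqualifies the request.
        if all(getattr(req, f"{name}_ms") <= bound for name, bound in slos_ms.items()):
            good_completed += 1
    return good_completed / duration_s


reqs = [
    RequestMetrics(ttft_ms=150, tpot_ms=40, e2el_ms=4000),  # meets both bounds
    RequestMetrics(ttft_ms=900, tpot_ms=20, e2el_ms=3000),  # fast decode, blew TTFT
    RequestMetrics(ttft_ms=120, tpot_ms=70, e2el_ms=7000),  # fast start, slow decode
]
print(goodput(reqs, {"ttft": 200, "tpot": 50}, duration_s=2.0))  # 0.5
```

Over this two-second window, throughput is 1.5 requests/s but goodput is 0.5. Each of the two rejected requests cleared one bound comfortably and still does not count.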

Goodput is the right optimization target precisely because it refuses to let you cheat. You cannot inflate it by cranking the batch size, because that bloats TTFT and the over-budget requests stop counting. You cannot inflate it by shrinking the batch to protect latency, because then you complete too few requests per second. It forces the engine toward the actual Pareto frontier of the two-latency, one-throughput tradeoff. The framing comes from the disaggregation literature, DistServe (arXiv:2401.09670), which optimized prefill and decode separately for goodput rather than raw throughput, an idea Chapter 17 takes apart; the vLLM CLI help even points you at that paper for the definition. For now the lesson is narrow and important: when you set an SLO, you are not setting one number, you are setting a conjunction, and goodput is how you measure whether the conjunction holds at a given load.

The cost denominator is GPU-seconds

The last metric is the one that pays the bills, and it is the one most often computed wrong. The instinct from stateless serving is cost-per-request, or cost-per-million-tokens as a flat rate. Both are misleading for the same reason: a request does not consume a request’s worth of a GPU, it consumes whatever slice of GPU time its tokens occupied while it was resident, and that slice depends on how well the engine packed it alongside others.

The correct denominator is GPU-seconds: the amount of GPU wall-clock time consumed, summed across the GPUs you rented. A GPU costs a fixed amount per second whether it is saturated or idle. So the only cost lever you have is how many useful tokens you extract per GPU-second, which is to say throughput per GPU, which is to say utilization. Every second a GPU spends not producing tokens is pure waste; the cheapest possible operation keeps it producing useful tokens every second it is alive. This reframes every optimization in the book as a cost optimization. Prefix caching (Chapter 7) is cheaper because it skips recomputing shared tokens, spending fewer GPU-seconds per request. Quantization (Chapter 12) is cheaper because it reads fewer bytes per token off memory and shrinks the memory footprint, so more sequences fit at once and each GPU-second yields more tokens. Speculative decoding (Chapter 13) is cheaper because it verifies several tokens per forward pass. None of these changes the per-request price directly; all of them change the GPU-seconds in the denominator.
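Converting GPU-seconds into a price per token takes one line of arithmetic. The sketch below assumes a hypothetical rate of USD 2.00 per GPU-hour and a hypothetical 4-GPU replica; substitute your own numbers.

```python
def cost_per_million_tokens(gpu_hour_price: float, num_gpus: int, tokens_per_second: float) -> float:
    """Price per 1M tokens for a replica sustaining tokens_per_second.

    The replica bills num_gpus GPU-seconds every wall-clock second, busy or idle,
    so at a fixed replica size the lever is tokens_per_second.
    """
    price_per_second = gpu_hour_price / 3600.0 * num_gpus
    return price_per_second / tokens_per_second * 1_000_000


# Hypothetical: a 4-GPU replica at USD 2.00 per GPU-hour.
for tps in (1000.0, 2000.0):
    print(f"{tps:6.0f} tok/s -> USD {cost_per_million_tokens(2.0, 4, tps):.2f} per 1M tokens")
```

Doubling the sustained token rate on the same hardware halves the cost per token, here from about 2.22 to about 1.11 per million. Nothing about the requests themselves changed.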

This is also why the two latency SLOs reappear inside the cost story. The cheapest way to run is at maximum batch size and maximum utilization, which is exactly the latency-hostile end of the tradeoff. Your SLOs cap how far toward cheap you are allowed to go. The signals you watch to know where you actually are on that curve are the scheduler gauges in vllm/v1/metrics/loggers.py, the queue depth and the cache pressure:

self.gauge_scheduler_running[engine_idx].set(scheduler_stats.num_running_reqs)
...
self.gauge_kv_cache_usage[engine_idx].set(scheduler_stats.kv_cache_usage)

vllm:num_requests_running is the number of sequences currently in the batch (your batch occupancy), vllm:num_requests_waiting is the queue depth (your backlog), and vllm:kv_cache_usage_perc is how full the KV cache is, which is how close the engine is to running out of the memory that bounds concurrency (Chapter 4 explains why that memory, not compute, is usually the wall). Read together they locate you on the knee from the previous section. When waiting is near zero and KV usage has slack, you are on the flat part of the curve with headroom to spare. When waiting climbs and KV usage pins near 1.0, you are at or past the knee: the engine cannot admit new work without evicting old work, so TTFT’s tail is about to explode and you are either out of headroom or about to start preempting (Chapter 6’s load-shedding mechanism). These same gauges become the inputs to autoscaling in Chapter 20 and to cache-aware routing in Chapter 18, which is the sense in which this chapter’s metrics are the language the rest of the fleet speaks.
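Reading the three gauges together can be sketched as a tiny scraper plus a heuristic. The metric names are the ones above. The sample scrape, its label names and values, and the 0.9 threshold are assumptions for illustration. They are not vLLM defaults or a recommended policy.

```python
import re

# Hypothetical scrape of one engine's /metrics endpoint in Prometheus text format;
# the label values and numbers are made up for illustration.
SAMPLE = """\
# TYPE vllm:num_requests_running gauge
vllm:num_requests_running{engine="0",model_name="my-model"} 48.0
vllm:num_requests_waiting{engine="0",model_name="my-model"} 17.0
vllm:kv_cache_usage_perc{engine="0",model_name="my-model"} 0.97
"""

# metric name, optional {labels}, value, optional integer timestamp
SAMPLE_LINE = re.compile(r"^([a-zA-Z_:][a-zA-Z0-9_:]*)(?:\{[^}]*\})?\s+(\S+)(?:\s+\d+)?$")


def read_gauges(text: str) -> dict[str, float]:
    """Map metric name -> value, assuming a single engine (one sample per name)."""
    gauges: dict[str, float] = {}
    for line in text.splitlines():
        match = SAMPLE_LINE.match(line.strip())  # comment and blank lines do not match
        if match:
            gauges[match.group(1)] = float(match.group(2))
    return gauges


def locate_on_knee(g: dict[str, float], kv_hot: float = 0.9) -> str:
    """Hypothetical heuristic; kv_hot = 0.9 is an illustrative threshold, not a vLLM setting."""
    waiting = g.get("vllm:num_requests_waiting", 0.0)
    kv_usage = g.get("vllm:kv_cache_usage_perc", 0.0)  # a 0..1 fraction
    if waiting == 0 and kv_usage < kv_hot:
        return "flat part of the curve: headroom to spare"
    if waiting > 0 and kv_usage >= kv_hot:
        return "at or past the knee: TTFT tail and preemption risk"
    return "approaching the knee: queue or cache is filling"


print(locate_on_knee(read_gauges(SAMPLE)))
```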

Handoff

So “fast” decomposes into a small, precise vocabulary: TTFT and ITL/TPOT as two latency SLOs produced by the two phases, throughput as the rate that fights them, goodput as the conjunction that holds them honest, and GPU-seconds as the denominator that turns all of it into money. We now have the metrics. What we do not yet have is the explanation for why the two phases behave so differently, why prefill is the one that saturates compute and decode is the one that starves on memory bandwidth. That asymmetry is the root cause of nearly every technique that follows, and it is the subject of Chapter 3.

Further reading

  • Taming Throughput-Latency Tradeoff in LLM Inference with Sarathi-Serve — arXiv:2403.02310 — formalizes the prefill/decode throughput-latency tension and the chunked-prefill scheduling that navigates it; read it for why no single batch size wins both metrics.
  • DistServe: Disaggregating Prefill and Decoding for Goodput-optimized LLM Serving — arXiv:2401.09670 — the source of the goodput definition vLLM’s --goodput flag uses; read it for how separating the two phases changes what you can optimize (revisited in Chapter 17).

The hardware reality: memory-bound decode, compute-bound prefill, arithmetic intensity

Chapter 1 framed a request as a long-lived two-phase state machine, and Chapter 2 gave us the vocabulary to argue about it: TTFT for the first phase, ITL/TPOT for the second, throughput and goodput for the fleet, GPU-seconds for the bill. Those metrics described what we measure. This chapter explains why the two phases behave so differently that they need separate metrics in the first place. The answer is not in the model architecture or the scheduler. It is in the silicon, and specifically in one number that every serving engineer eventually learns to feel in their bones: the ratio of arithmetic operations to bytes moved.

Here is the claim, and it is worth stating bluntly because nearly every technique in the rest of this book is a corollary of it. Prefill is compute-bound. Decode is memory-bandwidth-bound. They are bottlenecked on different physical resources, and a serving stack that ignores this will leave one of those resources idle while the other saturates. If you internalize only one idea from Part I, make it this one.

The roofline, in one breath

You already reason about systems in terms of bottleneck resources: a proxy is CPU-bound here, network-bound there. GPUs have the same dichotomy, formalized as the roofline model. Two definitions first, because everything in this chapter hangs on them. A kernel is one unit of work the GPU runs — a matrix multiply, an attention computation, an element-wise add. Its arithmetic intensity is a property of that work: the number of floating-point operations it performs divided by the number of bytes it must read from memory to perform them. It is a ratio, $\text{FLOPs} / \text{byte}$, and it does not depend on how fast the hardware is. It only describes how much math the kernel squeezes out of each byte it fetches.

Now bring in the hardware. HBM (high-bandwidth memory) is the GPU’s main memory — the multi-gigabyte pool where weights and activations live. It is fast by RAM standards and glacial by compute standards, and that gap is the whole story. The chip has two hard ceilings: peak compute, how many FLOP/s the tensor cores can grind through, and peak memory bandwidth, how many bytes/s it can pull from HBM. Divide the first by the second and you get the machine balance point $I_{\text{balance}}$, the arithmetic intensity at which the two ceilings are reached simultaneously:

$$I_{\text{balance}} = \frac{\text{peak FLOP/s}}{\text{peak bytes/s}}$$

It is the break-even ratio: the number of FLOPs the hardware can do in the time it takes to fetch one byte.

The model is called a “roofline” because of the shape you get when you plot achievable performance against arithmetic intensity. To the left of the balance point the ceiling is a slanted line set by bandwidth; to the right it flattens into a horizontal line set by compute. A kernel whose arithmetic intensity sits to the left of the balance point finishes its handful of FLOPs long before the bytes it needs have arrived, so the tensor cores stall waiting on HBM — it is memory-bound, and the only way to go faster is to move fewer bytes. A kernel to the right has so much math per byte that the data arrives faster than the cores can consume it — it is compute-bound, and the only way to go faster is more FLOP/s. The single question that decides a kernel’s fate is which side of the balance point its arithmetic intensity lands on.

The curve below is that plot, on log-log axes for a representative datacenter GPU (~3 TB/s of HBM bandwidth, ~1000 TFLOP/s of peak compute, so a balance point near 333 FLOP/byte). The sloped left edge is the bandwidth ceiling, the flat right edge is the compute ceiling, and their corner is the machine balance. The two markers are the operating points this chapter is about: decode lands at the far left (arithmetic intensity ~1, pinned to the bandwidth ceiling with the tensor cores almost idle), while prefill, whose intensity grows with prompt length, sails past the balance point and rides the flat compute roof.

Illustrative: the ceiling and operating points use representative GPU specs to show the shape and the relative positions of the two phases, not measured kernel data.

The roofline above is the lens we apply to both phases. A modern datacenter GPU has a machine balance somewhere in the low hundreds of FLOP per byte. Hold that number; both phases get measured against it, and the rest of this chapter is the story of how prefill lands on the right of it and decode crashes to the far left.
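The roofline itself is one line of arithmetic: attainable throughput is the smaller of peak compute and bandwidth times intensity. The sketch below uses the representative figures from the plot (about 1000 TFLOP/s and 3 TB/s), not any specific part.

```python
PEAK_FLOPS = 1000e12  # ~1000 TFLOP/s, representative peak compute
PEAK_BW = 3e12        # ~3 TB/s, representative HBM bandwidth
I_BALANCE = PEAK_FLOPS / PEAK_BW  # ~333 FLOP/byte


def attainable_flops(intensity: float) -> float:
    """Roofline: capped by bandwidth * intensity on the left, by peak compute on the right."""
    return min(PEAK_FLOPS, PEAK_BW * intensity)


for name, ai in [("decode, batch 1", 1.0), ("balance point", I_BALANCE), ("long prefill", 2000.0)]:
    frac = attainable_flops(ai) / PEAK_FLOPS
    print(f"{name:16s} I={ai:7.1f} FLOP/byte -> {frac:6.1%} of peak compute")
```

A kernel at 1 FLOP/byte can use at most 0.3% of this chip's compute, however well it is written.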

Why prefill saturates compute

Prefill processes the whole prompt at once. A prompt of $S$ tokens flows through the model as an $S$-row activation matrix — each row is one token’s $d$-dimensional vector, where $d$ (the hidden size, a few thousand for a large model) is the width of the model’s internal representation. The heavy operations are the linear projections: a weight matrix of shape $(d \times d)$ multiplied against the $(S \times d)$ stack of token vectors. Count the work. The matmul performs on the order of $S \cdot d^2$ multiply-adds — one $d^2$ block of work for each of the $S$ rows. The bytes it must read are roughly $d^2$ weights (the weight matrix, loaded once) plus $S \cdot d$ activations (the input rows). For a large model and a prompt of any real length, the $S \cdot d^2$ arithmetic dwarfs the $d^2$ weight read.

Here is the key move, and it is what pulls prefill to the compute-bound side. You load the $(d \times d)$ weight matrix from HBM exactly once, then reuse it across all $S$ rows of the prompt. Every additional token in the prompt is more FLOPs spent against weights you have already paid to fetch. So arithmetic intensity scales with $S$. With 16-bit operands it works out to roughly $S$ FLOPs per byte while $S$ is small next to $d$. The longer the prompt, the more math per byte, and past a few hundred tokens it sails clean past a machine balance point in the low hundreds. The GPU’s tensor cores run hot. This is the phase where a GPU does what marketing says it does: teraflops of dense linear algebra. This is also why TTFT scales with prompt length the way it does (more rows, more compute, and prefill is on the compute ceiling), and it is the physical reason Chapter 17 will argue for putting prefill on its own hardware pool.
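You can check the scaling with the chapter's own counting. The projection performs $2 S d^2$ FLOPs (a multiply-add is two FLOPs) against $2(d^2 + S d)$ bytes of 16-bit reads, which simplifies to $S d / (d + S)$ FLOP/byte. The sketch below assumes a hidden size of 8192 and ignores attention's own FLOPs and the output write.

```python
def projection_intensity(rows: int, d: int, bytes_per_elem: int = 2) -> float:
    """FLOP/byte for an (rows x d) @ (d x d) matmul, counting reads only, as in the text."""
    flops = 2 * rows * d * d                           # one multiply-add per weight per row
    bytes_read = bytes_per_elem * (d * d + rows * d)   # weights once, plus the input rows
    return flops / bytes_read


D = 8192  # hypothetical hidden size for a large model
for s in (1, 48, 400, 4096):
    print(f"S={s:5d}  intensity={projection_intensity(s, D):7.1f} FLOP/byte")
```

One row lands at about 1 FLOP/byte and 48 rows at about 48. The projection crosses a balance point near 333 only at around 350 rows.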

Why decode starves on bandwidth

Now the autoregressive loop. After prefill, the model emits one token, appends it to the sequence, and runs another full forward pass to get the next — but this time the input is a single token, while the weights are still the entire model. That mismatch is the whole problem. The matmuls degrade from matrix-matrix to matrix-vector: instead of $(S \times d) @ (d \times d)$ you now compute $(1 \times d) @ (d \times d)$, a single row against the weight matrix. You read the same $d^2$ weights from HBM as before, but now you do only $d^2$ multiply-adds against that one row before moving on. There is no second, third, or $S$-th row to amortize the fetch over. Each weight is loaded from HBM, used exactly once, and discarded. Arithmetic intensity collapses to roughly 1 FLOP per byte — far below the machine balance — and so the tensor cores sit idle while the memory bus runs flat out. This is the same $(S \times d) @ (d \times d)$ projection as prefill; the only thing that changed is $S$ fell from “the whole prompt” to one.

The diagram below traces the two phases side by side, so the asymmetry is concrete: same weights, same kernel, opposite bottleneck.

flowchart LR
    subgraph P["prefill: S input rows"]
        direction TB
        P1["read (d x d) weights once"] --> P2["multiply against all S rows"]
        P2 --> P3["S x d^2 FLOPs per d^2 bytes read"]
        P3 --> P4["arithmetic intensity grows with S, compute-bound"]
    end
    subgraph D["decode: 1 input row"]
        direction TB
        D1["read (d x d) weights once"] --> D2["multiply against the single row"]
        D2 --> D3["d^2 FLOPs per d^2 bytes read"]
        D3 --> D4["arithmetic intensity ~1, memory-bound"]
    end

The brutal consequence, visible in the bottom branch of the diagram: a single decode step’s latency is, to first order, the time to stream the model’s weights through the memory bus once. Walk the floor explicitly. A 70B-parameter model in 16-bit precision (2 bytes per parameter) is ~140 GB of weights. Every decode step must read all of them at least once, because every layer’s projection needs its slice of those weights. On a GPU with ~3 TB/s of bandwidth, reading 140 GB once takes ~47 ms,

$$\frac{140 \text{ GB}}{3000 \text{ GB/s}} \approx 47 \text{ ms}$$

and that is a floor: even if the tensor cores were infinitely fast, the bus cannot deliver the weights faster, so no amount of spare compute can lower it. This is the single most important sentence for an infra engineer to absorb, because it reframes the entire problem. You are not compute-starved during the phase that dominates wall-clock time for long generations. You are bandwidth-starved, and the levers that matter are the ones that move fewer bytes.
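The floor is a single division, and it makes the bytes lever explicit. The sketch below assumes the text's 70B model and a 3 TB/s bus. The 8-bit and 4-bit rows preview the quantization lever named at the end of this chapter. They ignore the scale metadata that real quantized formats carry.

```python
def decode_step_floor_s(num_params: float, bytes_per_param: float, hbm_bytes_per_s: float) -> float:
    """Lower bound on one decode step: stream every weight through the memory bus once.

    Ignores the KV-cache read and all compute, so real steps are slower.
    """
    return num_params * bytes_per_param / hbm_bytes_per_s


for label, bpp in [("16-bit", 2.0), ("8-bit", 1.0), ("4-bit", 0.5)]:
    floor_ms = decode_step_floor_s(70e9, bpp, 3e12) * 1000
    print(f"70B @ {label:6s}: >= {floor_ms:5.1f} ms/step, <= {1000 / floor_ms:5.1f} tokens/s per sequence")
```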

And it is not only weights that the bus must carry. The KV cache from Chapter 1 — the stored keys and values for every prior token — must also be read on every decode step, and here is why it has to be. Attention works by having the new token compare itself against every earlier token (those are the keys) and then pull a weighted blend of what each earlier token carried (those are the values). So to generate token number $n$, the step must read the keys and values of all $n - 1$ tokens before it. That read grows with the length of the conversation, and it is paid again on every single step. Long contexts and large batches make the KV cache rival or exceed the weights in bytes moved per step — a second bandwidth bill stacked on top of the weight read, and one that, unlike the fixed weight cost, keeps growing as the sequence gets longer. That is exactly why Chapter 4 treats KV memory, not compute, as the binding constraint on concurrency, and why Chapter 12’s KV-cache quantization and Chapter 9’s MLA both attack the bytes of the cache rather than its FLOPs.
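The following sketch puts numbers on "rival or exceed" by estimating the per-step KV read. The model shape (80 layers, 8 KV heads of dimension 128, 16-bit cache) is an assumption chosen to resemble a 70B model with grouped-query attention. The formula counts one K and one V vector per layer per token.

```python
def kv_bytes_per_token(num_layers: int, num_kv_heads: int, head_dim: int, bytes_per_elem: int = 2) -> int:
    """Bytes that one token's keys plus values occupy across all layers."""
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem  # leading 2 = K and V


per_token = kv_bytes_per_token(num_layers=80, num_kv_heads=8, head_dim=128)  # assumed shape
WEIGHT_BYTES = 140e9  # the 70B, 16-bit weights from the floor calculation

for batch, context in [(1, 4096), (64, 4096), (16, 32768)]:
    kv_read = batch * context * per_token  # every cached K/V is read once per decode step
    print(f"batch={batch:3d} context={context:6d}: KV read {kv_read / 1e9:6.1f} GB "
          f"vs {WEIGHT_BYTES / 1e9:.0f} GB of weights per step")
```

At 320 KiB per token, one 4K-context sequence adds about 1.3 GB per step and 64 of them add about 86 GB. Sixteen sequences at 32K context already read more KV than weights.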

This asymmetry has a name in the literature worth reading: “Data Movement Is All You Need: A Case Study on Optimizing Transformers” (arXiv:2007.00072). The paper’s central finding is that transformer workloads are dominated not by arithmetic but by data movement — the operations that look cheap on a FLOP count (normalizations, element-wise ops, the plumbing between matmuls) are disproportionately expensive once you account for memory traffic, and operator fusion that keeps data in fast memory is where the wins live. The case study is on training a BERT encoder, but the methodology — classify every operator by its arithmetic intensity, then fuse the memory-bound ones — is exactly the roofline lens this chapter applies to inference; read it for how to measure where the bytes actually go.

Where the step actually happens in vLLM

Enough physics. The asymmetry is not an abstraction in vLLM — it is visible in the shape of the tensors that flow through one step. The forward pass itself is a single call in the model runner (vllm/v1/worker/gpu_model_runner.py):

        return self.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
            **model_kwargs,
        )

Nothing here distinguishes prefill from decode. That is the point: the same code path runs both. What differs is the number of rows in input_ids. The scheduler hands the runner a flat batch of scheduled tokens, some of which happen to be a decode (one token for a running sequence) and some a prefill chunk (many tokens for a sequence catching up). The scheduler’s own design note (vllm/v1/core/sched/scheduler.py) puts it bluntly: “there’s no ‘decoding phase’ nor ‘prefill phase’ in the scheduler.” A batch heavy with prefill chunks has many rows and high arithmetic intensity; a batch of pure decodes has one row per sequence and craters into the memory-bound regime. The hardware reality and the scheduler’s uniformity are two views of the same thing, and Chapter 5 will make that uniform token-budgeted batch the central abstraction of the engine.

The decode-is-wasteful insight shows up most sharply in what happens after the forward pass. The model produces a hidden state for every input token, but a forward pass that processed $S$ prompt tokens still only needs to predict one next token — so only the final position’s hidden state is useful for sampling. vLLM gathers exactly those positions before computing logits (vllm/v1/worker/gpu_model_runner.py):

            sample_hidden_states = hidden_states[logits_indices]
            logits = self.model.compute_logits(sample_hidden_states)

And logits_indices is computed as the last token of each sequence in the batch (vllm/v1/worker/gpu_model_runner.py):

            logits_indices = query_start_loc[1:] - 1

Unpack that one-liner, because it is a tidy piece of indexing. query_start_loc is the prefix-sum of per-request token counts: if the batch holds requests of 5, 1, and 3 rows, it reads [0, 5, 6, 9], marking where each request’s rows begin in the flat batch. Drop the leading 0 ([1:] gives [5, 6, 9]) and subtract one ([4, 5, 8]) and you have the index of the last row of every request — exactly the position whose hidden state predicts that request’s next token. Everything else the prefill computed — the hidden states for all the interior prompt tokens — is thrown away for sampling purposes, because no token needs to be predicted from the middle of a prompt you already have. The comment a few lines up is candid about a related subtlety: a chunked-prefill request that has not finished its prompt produces a logit too, but the runner notes it will “ignore the sampled tokens from the partial requests.” This is the code-level confirmation of a fact this chapter has been circling: prefill produces no output token of its own. Its entire job is to populate the KV cache so that decode can begin. Chapter 11 will pick up this thread — the sampler operates only on [num_reqs, vocab] logits, one row per request — and explain why the only mandatory GPU-to-CPU sync per step is shipping back those sampled ids.

Speaking of which: the asymmetry even governs the one place the GPU must talk to the CPU. After sampling, the token ids cross back over the PCIe bus (vllm/v1/worker/gpu_model_runner.py):

                valid_sampled_token_ids = self._to_list(sampled_token_ids)

For a decode-dominated batch this transfer is tiny — one integer per sequence — yet it imposes a synchronization point that, left naive, would stall the bandwidth-bound GPU on a slow CPU round-trip. Chapter 10’s async scheduling exists precisely because, when each step is already as cheap as one weight-streaming pass, even a CPU hiccup becomes a visible fraction of ITL. You can already see the seam in the code: an alternate path caches the sampled ids on the GPU as prev_sampled_token_ids “to avoid CPU sync.” Hold that for Chapter 10.
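The logits_indices one-liner unpacked earlier in this section is easy to replay without a GPU. The following plain-Python sketch uses lists instead of tensors. The function and parameter names are illustrative, not vLLM's.

```python
from itertools import accumulate


def last_token_indices(tokens_per_request: list[int]) -> list[int]:
    """Index of each request's final row in the flattened batch.

    query_start_loc is the exclusive prefix sum [0, n0, n0+n1, ...]; dropping the
    leading 0 and subtracting one yields the last row of every request.
    """
    query_start_loc = [0, *accumulate(tokens_per_request)]
    return [end - 1 for end in query_start_loc[1:]]


print(last_token_indices([5, 1, 3]))  # [4, 5, 8]
```

A pure-decode batch of one row per request maps every row to itself, which is why decode wastes nothing at this step while prefill discards all but its final row.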

A note on backends, and why this book stays on GPUs

vLLM is deliberately multi-backend. The same attention interface is implemented for CUDA (FlashAttention, FlashInfer), AMD’s ROCm (vllm/v1/attention/backends/rocm_attn.py), CPUs (vllm/v1/attention/backends/cpu_attn.py), Intel XPU, and TPU, selected at startup behind the dispatch seam Chapter 9 will dissect. The roofline argument is not GPU-specific — every accelerator has a compute ceiling and a bandwidth ceiling and therefore a balance point — but the numbers that make decode catastrophically memory-bound (hundreds of GB of weights versus single-token activations) are sharpest on high-bandwidth GPU memory, and the production fleet this book targets is GPU-based. So from here on the prose says “GPU” and “HBM” and means it, while you should remember the principle ports to any backend with the obvious substitutions.

The tradeoff this creates, and the one it does not solve

The asymmetry hands you a lever and a trap. The lever: decode wastes compute, so you can do more arithmetic per byte moved almost for free. Here is the mechanism that makes it free. A lone decode reads the whole weight matrix to multiply against one row. Put $B$ sequences’ decode steps in the same batch and the matmul becomes $(B \times d) @ (d \times d)$ — you still read each weight from HBM exactly once, but now you reuse it across all $B$ rows. The bytes moved barely change; the FLOPs multiply by $B$. Arithmetic intensity climbs back toward the balance point, and you have converted idle compute into served tokens at almost no extra bandwidth cost. The curve below shows that climb: decode arithmetic intensity rises roughly linearly with batch size, so a lone decode sits deep in the memory-bound regime while a batch of a few hundred reaches the balance point and finally puts the tensor cores to work. This is the throughput half of the throughput/latency tradeoff from Chapter 2, and the physical reason continuous batching (Chapter 5) is the foundational optimization. Speculative decoding (Chapter 13) is the same lever from another angle: verify several draft tokens in one weight-streaming pass, spending idle compute to amortize the memory read.

Illustrative: decode arithmetic intensity scales as roughly one FLOP/byte per batched sequence; the crossing point depends on the GPU’s balance point (here ~333) and is shown to convey the shape, not a measured kernel.
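You can find the batch size at which decode reaches the balance point by counting $2Bd^2$ FLOPs against $2(d^2 + Bd)$ bytes of 16-bit reads. This gives an intensity of $Bd/(d+B)$, and solving $Bd/(d+B) \ge I_{\text{balance}}$ yields $B \ge I_{\text{balance}}\, d / (d - I_{\text{balance}})$. The sketch below uses a hypothetical hidden size and also shows how the 140 GB weight read is split across the batch.

```python
import math


def min_batch_for_balance(balance: float, d: int) -> int:
    """Smallest decode batch B with B*d / (d + B) >= balance (16-bit, reads only)."""
    if balance >= d:
        raise ValueError("unreachable: this intensity never exceeds d")
    return math.ceil(balance * d / (d - balance))


D = 8192              # hypothetical hidden size
WEIGHT_BYTES = 140e9  # the 70B, 16-bit model used earlier in the chapter

b_star = min_batch_for_balance(1000 / 3, D)  # ~333 FLOP/byte balance point
print(f"smallest batch reaching the balance point: {b_star}")
for b in (1, 32, b_star):
    print(f"batch={b:4d}: {WEIGHT_BYTES / b / 1e9:7.2f} GB of weight traffic per generated token")
```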

The trap: batching decode helps only until you run out of the other scarce resource. Every sequence you add to the batch carries its own KV cache, and KV cache is GPU memory — capacity, the third ceiling, distinct from both compute and bandwidth. So as you crank up batch size to chase throughput, the binding constraint silently migrates. The state diagram below traces the journey: a lone decode is bandwidth-bound; batch it and you climb back toward compute-bound, the regime you wanted; keep batching and the accumulated KV cache eventually fills HBM, and now you are capacity-bound and cannot admit another sequence at all. Which ceiling bites first depends on model size, context length, and load, and the path is not one-way — a batch of long-context decodes can be capacity-bound before it ever feels compute-bound. Naming that migration precisely is the job of the next chapter.

stateDiagram-v2
    [*] --> MemoryBound: single decode step
    MemoryBound --> ComputeBound: batch more sequences (reuse weights)
    ComputeBound --> CapacityBound: KV cache fills HBM
    CapacityBound --> MemoryBound: preempt or evict a sequence (Ch 6)
    MemoryBound: bandwidth-bound, bus saturated, cores idle
    ComputeBound: compute-bound, tensor cores busy, good throughput
    CapacityBound: capacity-bound, no HBM left for more KV cache
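The claim that long-context decodes can hit the capacity wall first can be checked with a rough sizing sketch. Everything here is assumed for illustration: an 8B-parameter 16-bit model on an 80 GB GPU, a made-up 10% reserve for activations, and a KV shape of 32 layers with 8 KV heads of dimension 128. In practice the reserve is measured rather than guessed.

```python
def max_resident_sequences(hbm_bytes: float, weight_bytes: float, kv_bytes_per_token: int,
                           context_len: int, reserve_frac: float = 0.1) -> int:
    """Upper bound on full-length sequences whose KV cache fits beside the weights.

    reserve_frac is a hypothetical slice held back for activations and other buffers.
    """
    free_bytes = hbm_bytes * (1 - reserve_frac) - weight_bytes
    if free_bytes <= 0:
        return 0
    return int(free_bytes // (kv_bytes_per_token * context_len))


# Assumed shape: 32 layers, 8 KV heads of dimension 128, K and V at 2 bytes each.
KV_PER_TOKEN = 2 * 32 * 8 * 128 * 2
for ctx in (2048, 8192, 32768):
    n = max_resident_sequences(80e9, 16e9, KV_PER_TOKEN, ctx)
    print(f"context={ctx:6d}: at most {n:4d} sequences in flight")
```

At an 8K context this replica tops out around 52 sequences. That is well short of the few hundred a decode batch needs to reach the balance point, so it turns capacity-bound before it ever turns compute-bound.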

What the hardware does not hand you is a way to make a single decode step faster than one pass over the bytes it must read. You cannot out-schedule the speed of light on the memory bus. Every technique that genuinely lowers per-token latency, rather than raising throughput, works by reducing the bytes: quantize the weights so there are fewer to stream (Chapter 12), compress the KV cache so attention reads less (Chapters 9 and 12), or skip the read entirely for shared prefixes that are already resident (Chapter 7). Keep that taxonomy in hand. When a later chapter claims a speedup, the first question is always: did it add FLOPs to a bandwidth-bound step (cheap), or did it cut the bytes (the only thing that lowers the latency floor)?

We now know the shape of the problem and the physics underneath it. The next chapter follows the bytes to their largest and most dynamic source — the KV cache — and shows why treating it as paged memory, not one buffer per sequence, is what lets a single replica hold enough sequences in flight to keep that hungry memory bus busy.

Further reading

  • Data Movement Is All You Need: A Case Study on Optimizing Transformers — arXiv:2007.00072 — transformer workloads are dominated by data movement, not arithmetic; a BERT-training case study whose roofline/fusion methodology is the foundation for the memory-bound decode argument.

KV cache and PagedAttention: memory as the real bottleneck

Chapter 3 left us with an uncomfortable asymmetry. Decode is memory-bandwidth-bound: every step reads the whole model and the whole KV cache to emit a single token, and you cannot make that token come out faster by adding FLOPs. The obvious escape hatch is concurrency. If one sequence’s decode step leaves the matrix units mostly idle, run more sequences in the same step and amortize that weight read across all of them. Batching is free throughput, right up until you run out of the one resource decode actually consumes.

That resource is not compute. It is memory, and specifically the place we store the keys and values for every token already in flight. This chapter is about why the KV cache, not the GEMM, is what caps how many requests a replica can hold at once, and about the single data-structure decision that turns that cap from “embarrassingly low” into “the highest the hardware allows.” Everything in Part II builds on the allocator we develop here.

Why the KV cache is the budget

Recall the shape of attention during decode. For each new token, the model computes a query and attends over the keys and values of every preceding token in the sequence. Those past K and V tensors do not change, so we cache them rather than recomputing them. The cache grows by one token’s worth of K and V per layer, per decode step, for the entire life of the sequence.

How much is that? vLLM’s own accounting lives in its KV-cache spec. A full-attention layer reports its worst-case footprint as a function of length:

# vllm/v1/kv_cache_interface.py  (FullAttentionSpec.max_memory_usage_bytes)
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
    max_model_len = vllm_config.model_config.max_model_len
    ...
    return cdiv(max_model_len, self.block_size) * self.page_size_bytes

Source: vllm/v1/kv_cache_interface.py

and page_size_bytes for a block of block_size tokens is, stripped of the quantization branches:

# vllm/v1/kv_cache_interface.py  (AttentionSpec.real_page_size_bytes)
return (
    2  # one for K, one for V
    * self.block_size
    * self.num_kv_heads
    * self.head_size
    * get_dtype_size(self.dtype)
)

Source: vllm/v1/kv_cache_interface.py

Walk that formula left to right and it is just counting the numbers that have to be stored. The factor of two is K and V, since attention keeps both. block_size is how many tokens share a block. num_kv_heads is the number of distinct key/value heads the layer keeps; with grouped-query attention several query heads share one KV head, which is precisely the architectural trick that shrinks this term. head_size is the width of each head’s vector, and get_dtype_size(self.dtype) is the bytes per number (2 for bf16). Multiply all of that and you have the bytes for one block of one layer; multiply through by the number of layers $L$ and divide by block_size to get the per-token cost:

$$\text{bytes/token} = 2 \cdot L \cdot n_{\text{kv}} \cdot d_{\text{head}} \cdot b_{\text{dtype}}$$

where $n_{\text{kv}}$ is the number of KV heads, $d_{\text{head}}$ the head width, and $b_{\text{dtype}}$ the bytes per number. For a typical 7-8B model in bf16 with grouped-query attention, that lands somewhere around 100-150 KB per token. A single 8k-token context is a gigabyte or more of KV. On an 80 GB card holding ~16 GB of weights, you have maybe 60 GB left for cache after the runtime’s overhead, and that 60 GB is what you are dividing among all concurrent sequences. The KV cache, not the model, decides your concurrency.
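To make the arithmetic concrete, the sketch below evaluates the formula for an assumed Llama-3-8B-like shape: 32 layers, 8 KV heads, head width 128, bf16. That shape is an assumption chosen to match the “typical 7-8B GQA model” in the text, so substitute your own model’s config values. The 60 GB budget is the rough figure from the paragraph above.

```python
def kv_bytes_per_token(num_layers: int, num_kv_heads: int, head_size: int, dtype_bytes: int) -> int:
    """bytes/token = 2 * L * n_kv * d_head * b_dtype (the 2 is one K plus one V)."""
    return 2 * num_layers * num_kv_heads * head_size * dtype_bytes


# Assumed Llama-3-8B-like shape: 32 layers, 8 KV heads (GQA), head width 128, bf16.
gqa = kv_bytes_per_token(num_layers=32, num_kv_heads=8, head_size=128, dtype_bytes=2)
# Same shape without GQA (one KV head per query head, 32 of them), for comparison.
mha = kv_bytes_per_token(num_layers=32, num_kv_heads=32, head_size=128, dtype_bytes=2)

context_len = 8192
per_sequence = gqa * context_len
kv_budget = 60 * 10**9  # ~60 GB left for KV after weights and runtime overhead

print(f"GQA: {gqa / 1024:.0f} KiB/token, MHA: {mha / 1024:.0f} KiB/token")
print(f"one {context_len}-token context: {per_sequence / 2**30:.2f} GiB")
print(f"full-length sequences that fit: {kv_budget // per_sequence}")
```

Under these assumptions the result is 128 KiB per token and exactly 1 GiB per full 8k context, so about 55 such sequences fit. The 4x gap between the GQA and MHA lines is the architectural trick the text credits with shrinking the $n_{\text{kv}}$ term.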

So the real question of single-replica serving is: how do you pack the most live sequences into a fixed KV budget? And the answer turns out to depend almost entirely on how you allocate that budget.

The naive layout and the waste it bakes in

The pre-PagedAttention way to store a sequence’s KV cache is the way you would store anything else: one contiguous buffer per sequence. You do not know how long a request will run, so you reserve for the worst case, max_model_len tokens, up front.

This is a disaster for two reasons that any systems engineer will recognize as the classic fragmentation pair. First, internal fragmentation: a request that reserves 8k slots but generates 200 tokens wastes the other 7,800 for its entire lifetime. Output length is unknown at admission and wildly variable in real traffic, so this is not an edge case; it is the common case. Second, external fragmentation: contiguous buffers of varying sizes leave unusable gaps between them. This is exactly the problem a general-purpose heap allocator faces when it has no fixed size classes. The originating paper measured the result: existing systems used only about 20 to 38 percent of their KV memory for actual token states. The rest was reserved-but-empty or stranded-between-buffers. You were paying for 60 GB and serving as if you had 15.

The curve below makes the internal-fragmentation cost concrete. It plots the KV memory a single sequence consumes against the number of tokens it actually generates, holding max_model_len at 8k. Contiguous allocation reserves the worst case up front, so it sits flat at ~1 GB no matter how short the request turns out to be; paging tracks actual usage almost exactly, with waste capped at one partial block. The vertical gap between the two lines, at any length short of the maximum, is pure reserved-but-empty memory.

Illustrative: assumes ~125 KB/token (a typical 7-8B bf16 GQA model) and a 16-token block; the gap between the lines is the reserved-but-empty waste paging removes, not measured values.
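The two curves can be reproduced in a few lines under the caption’s assumptions: 125,000 bytes per token, 16-token blocks, and max_model_len = 8192. The function names are illustrative, not vLLM APIs. cdiv is ceiling division, like the helper used in the spec code quoted earlier.

```python
BYTES_PER_TOKEN = 125_000  # assumed, matching the figure's ~125 KB/token
BLOCK_SIZE = 16
MAX_MODEL_LEN = 8192


def cdiv(a: int, b: int) -> int:
    """Ceiling division."""
    return -(-a // b)


def contiguous_bytes(num_tokens: int) -> int:
    # Worst-case reservation made at admission; independent of actual length.
    return MAX_MODEL_LEN * BYTES_PER_TOKEN


def paged_bytes(num_tokens: int) -> int:
    # Only the blocks actually touched, so waste is under one block.
    return cdiv(num_tokens, BLOCK_SIZE) * BLOCK_SIZE * BYTES_PER_TOKEN


for n in (200, 1000, 4000, 8192):
    waste = contiguous_bytes(n) - paged_bytes(n)
    print(f"{n:5d} tokens: contiguous {contiguous_bytes(n) / 1e9:.3f} GB, "
          f"paged {paged_bytes(n) / 1e9:.3f} GB, reserved-but-empty {waste / 1e9:.3f} GB")
```

At 200 generated tokens, the contiguous layout holds about 1.02 GB while paging holds 26 MB. The two converge only when a request actually runs to max_model_len.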

The fix is the same fix operating systems applied to exactly this problem decades ago. Stop demanding contiguity. Page it.

Paging the cache: fixed blocks and a pool

PagedAttention (the technique introduced in Efficient Memory Management for Large Language Model Serving with PagedAttention, arXiv:2309.06180) carves KV memory into fixed-size blocks, each holding a small fixed number of tokens, and lets a sequence’s blocks live anywhere in the pool. A block is just a fixed-size chunk of GPU memory big enough to hold the K and V tensors for a handful of consecutive tokens. The pool is the full set of these blocks, all carved from the leftover KV budget at startup. The trick is that the blocks for one sequence do not have to sit next to each other in the pool. The sequence’s logical token order, token 0, then token 1, then token 2, is reconstructed by a per-sequence block table: a small array that says “my first block of tokens lives at physical block 7, my second at physical block 2, my third at physical block 9.” This is the direct analogue of an OS page table mapping a process’s contiguous virtual addresses onto scattered physical frames. The application thinks it has one contiguous run; the hardware stores it in whatever holes are free.

The diagram below traces that indirection for a single sequence. The logical view on the left is the contiguous token stream the model reasons about; the block table in the middle is the lookup the engine actually keeps; the pool on the right is physical GPU memory, where this sequence’s blocks are interleaved with other sequences’ blocks and free space.

flowchart LR
    subgraph Logical["Logical view (one sequence)"]
        L0["tokens 0-15"]
        L1["tokens 16-31"]
        L2["tokens 32-47"]
    end
    subgraph Table["Block table"]
        T0["slot 0 to phys 7"]
        T1["slot 1 to phys 2"]
        T2["slot 2 to phys 9"]
    end
    subgraph Pool["Physical block pool (GPU)"]
        P2["phys 2"]
        P7["phys 7"]
        P9["phys 9"]
        PX["other seqs / free"]
    end
    L0 --> T0 --> P7
    L1 --> T1 --> P2
    L2 --> T2 --> P9

Because every block is the same size, the engine never has to find a hole of a particular shape; any free block will do for any sequence. That single property is what dissolves external fragmentation.
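The block-table indirection fits in a small class. This is a toy sketch: BlockTable, append_tokens, and locate are hypothetical names, and the shared free list is a plain Python list for brevity. It reuses the physical block numbers from the example above (7, 2, 9).

```python
class BlockTable:
    """Per-sequence map from logical block index to physical block id (toy)."""

    def __init__(self, block_size: int, free_blocks: list[int]):
        self.block_size = block_size
        self.free_blocks = free_blocks  # shared pool of free physical block ids
        self.physical: list[int] = []   # physical[i] = where logical block i lives
        self.num_tokens = 0

    def append_tokens(self, n: int) -> None:
        for _ in range(n):
            # A new block is needed only when the last one is full.
            if self.num_tokens % self.block_size == 0:
                self.physical.append(self.free_blocks.pop(0))
            self.num_tokens += 1

    def locate(self, token_idx: int) -> tuple[int, int]:
        """Translate a logical token position to (physical block, offset)."""
        logical_block, offset = divmod(token_idx, self.block_size)
        return self.physical[logical_block], offset


pool = [7, 2, 9, 4]        # free physical blocks, in whatever order they came back
table = BlockTable(block_size=16, free_blocks=pool)
table.append_tokens(48)    # tokens 0-47 need three blocks
print(table.physical)      # [7, 2, 9]
print(table.locate(20))    # (2, 4): token 20 is slot 4 of physical block 2
print(pool)                # [4]: one block still free
```

The sequence never asks for a particular block. It takes whatever is at the front of the free list, and the table remembers where each block ended up.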

vLLM’s pool is a flat list of block objects, one per physical block, all initially free:

# vllm/v1/core/block_pool.py
self.blocks: list[KVCacheBlock] = [
    KVCacheBlock(idx) for idx in range(num_gpu_blocks)
]
# Free block queue that constructs and manipulates a doubly linked
# list of free blocks (including eviction candidates when caching is
# enabled).
self.free_block_queue = FreeKVCacheBlockQueue(self.blocks)

Source: vllm/v1/core/block_pool.py

The default block size is 16 tokens (DEFAULT_BLOCK_SIZE: ClassVar[int] = 16 in vllm/config/cache.py). That number is a tuning knob with a real tradeoff. With paging, the only internal waste is the unused tail of each sequence’s last, partially filled block. That tail is at most block_size - 1 tokens (15 at the default) regardless of context length, and smaller blocks shrink the cap further. Larger blocks cut bookkeeping and let the attention kernel process longer contiguous runs. Sixteen is the compromise that has held up.
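Both sides of the tradeoff can be quantified for a single sequence. Waste is whatever is left unused in the last block, and bookkeeping is the number of block-table entries. The sketch uses an arbitrary 1,000-token sequence, and block_stats is a hypothetical helper.

```python
def cdiv(a: int, b: int) -> int:
    return -(-a // b)


def block_stats(num_tokens: int, block_size: int) -> tuple[int, int]:
    """Return (block-table entries, wasted token slots in the last block)."""
    num_blocks = cdiv(num_tokens, block_size)
    wasted = num_blocks * block_size - num_tokens
    return num_blocks, wasted


for bs in (8, 16, 32, 128):
    blocks, wasted = block_stats(1000, bs)
    print(f"block_size={bs:3d}: {blocks:3d} table entries, {wasted:3d} wasted slots")
```

Waste never exceeds block_size - 1 slots. The table-entry count, by contrast, grows as blocks shrink: from 8 entries at 128 tokens per block to 125 entries at 8.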

Notice what fixed-size blocks buy you. Every free block is interchangeable, so there are no size classes and no external fragmentation, ever. A request grows by grabbing blocks from the free list one at a time as it produces tokens, and returns them to the free list when it finishes. The free list itself is a hand-rolled doubly-linked list rather than a deque precisely so a block can be plucked out of the middle in $O(1)$:

# vllm/v1/core/kv_cache_utils.py  (FreeKVCacheBlockQueue docstring)
# This class organizes a list of KVCacheBlock objects to a doubly linked
# list of free blocks. We implement this class instead of using Python
# builtin deque to support removing a block in the middle of the queue
# in O(1) time.

Source: vllm/v1/core/kv_cache_utils.py

That $O(1)$-middle-removal matters because the free list does double duty as an eviction queue for prefix caching, the subject of Chapter 7. When a sequence finishes, its blocks go back on the free list, but the K/V they hold is left intact in case a later request shares the same prefix and can reuse it. Such a block is “free” in the sense that the allocator may hand it out, yet “occupied” in the sense that it still carries useful data. Reusing it means reaching into the middle of the queue to pull out that specific block, hence the linked list. So hold this thought: a “free” block is not necessarily empty.
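The reason for a hand-rolled list is easy to see in a minimal version. The sketch below is a doubly linked free list with sentinel head and tail nodes. It illustrates the idea only and is not a copy of vLLM’s FreeKVCacheBlockQueue, whose fields and methods differ. ToyBlock and ToyFreeBlockQueue are hypothetical names.

```python
class ToyBlock:
    """A physical block carrying its own prev/next links (hypothetical)."""

    def __init__(self, idx: int):
        self.idx = idx
        self.prev = None
        self.next = None


class ToyFreeBlockQueue:
    """Doubly linked free list with sentinel ends: O(1) append, popleft, remove."""

    def __init__(self, blocks):
        self.head, self.tail = ToyBlock(-1), ToyBlock(-1)  # sentinels, never handed out
        self.head.next, self.tail.prev = self.tail, self.head
        self.num_free = 0
        for block in blocks:
            self.append(block)

    def append(self, block):
        last = self.tail.prev
        last.next, block.prev = block, last
        block.next, self.tail.prev = self.tail, block
        self.num_free += 1

    def remove(self, block):
        # Unlink from anywhere in the list; no scan is needed because the
        # block carries pointers to its own neighbours.
        block.prev.next, block.next.prev = block.next, block.prev
        block.prev = block.next = None
        self.num_free -= 1

    def popleft(self):
        if self.num_free == 0:
            raise IndexError("no free blocks")
        block = self.head.next
        self.remove(block)
        return block

    def ids(self):
        out, node = [], self.head.next
        while node is not self.tail:
            out.append(node.idx)
            node = node.next
        return out


blocks = [ToyBlock(i) for i in range(5)]
free = ToyFreeBlockQueue(blocks)
print(free.popleft().idx)   # 0: ordinary allocation takes from the front
free.remove(blocks[3])      # prefix-cache hit: reclaim one specific "free" block
print(free.ids())           # [1, 2, 4]
free.append(blocks[0])      # a finished request returns its block to the tail
print(free.ids())           # [1, 2, 4, 0]
```

A Python deque supports O(1) operations only at its ends; removing an arbitrary element means a linear search. Here the links live on the block itself, which is what makes the middle removal constant time.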

Allocation: where the budget check actually lives

The scheduler never touches blocks directly. It asks the KVCacheManager for slots, and the answer to “can this request run this step” is literally “did allocate_slots return blocks or None?” The whole admission decision compresses to a free-block count:

# vllm/v1/core/kv_cache_manager.py  (allocate_slots)
num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate(...)
available_blocks = self.block_pool.get_num_free_blocks() - reserved_blocks
required_blocks = num_blocks_to_allocate + watermark_blocks
if required_blocks > available_blocks:
    # Cannot allocate new blocks
    return None

Source: vllm/v1/core/kv_cache_manager.py

Read that inequality slowly, because it is the entire admission decision in four lines. num_blocks_to_allocate is how many fresh physical blocks this step needs, computed by the coordinator from how many new tokens the request is about to produce and how many partially-filled blocks it already holds. get_num_free_blocks() is the live count of blocks currently on the free list. The two correction terms are headroom: watermark_blocks is a small reserve held back so that newly-admitted waiting requests cannot grab the last blocks and starve a running sequence into immediate preemption, and reserved_blocks holds capacity for in-flight sequences during asynchronous KV transfers (Chapters 16 and 17). If demand exceeds the headroom-adjusted supply, the function returns None and the request simply does not run this step. There is no error, no partial allocation, just a clean “not now.”
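The inequality translates directly into a standalone function. This is a simplification: try_allocate is a hypothetical name, free blocks are a deque of integer ids, and the coordinator’s block computation is replaced by a number passed in.

```python
from collections import deque
from typing import Optional


def try_allocate(num_blocks_to_allocate: int, free_blocks: deque,
                 reserved_blocks: int, watermark_blocks: int) -> Optional[list[int]]:
    available_blocks = len(free_blocks) - reserved_blocks
    required_blocks = num_blocks_to_allocate + watermark_blocks
    if required_blocks > available_blocks:
        return None  # "not now": no error, nothing taken from the pool
    # Only the blocks actually needed are popped; the watermark is headroom.
    return [free_blocks.popleft() for _ in range(num_blocks_to_allocate)]


free = deque(range(10))
print(try_allocate(3, free, reserved_blocks=2, watermark_blocks=1))  # [0, 1, 2]
print(try_allocate(5, free, reserved_blocks=2, watermark_blocks=1))  # None: needs 6, only 5 usable
print(len(free))  # 7: the failed call took nothing
```

The check is all or nothing. A request either gets every block it asked for or none of them, which is what lets the scheduler treat None as a clean “skip or preempt” signal.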

The flowchart below traces that decision and what each outcome triggers downstream.

flowchart TD
    A["Scheduler: can this request run this step?"] --> B["allocate_slots asks coordinator"]
    B --> C["required = blocks_needed + watermark"]
    C --> D["available = free_blocks - reserved"]
    D --> E{"required > available?"}
    E -->|"yes"| F["return None"]
    E -->|"no"| G["pop blocks from free list, return KVCacheBlocks"]
    F --> H["scheduler preempts or defers (Chapter 6)"]
    G --> I["request runs: write K/V into its slots this step"]

This is the function that connects every later chapter’s machinery back to physical memory. A None here is what triggers the preemption-as-load-shedding loop in Chapter 6. Both correction terms are admission-control knobs bolted onto the same simple inequality.

The docstring for allocate_slots carries an ASCII diagram of the block layout that repays close reading, because it is the conceptual map for prefix caching, chunked prefill, and disaggregation all at once:

# vllm/v1/core/kv_cache_manager.py  (allocate_slots docstring)
----------------------------------------------------------------------
| < comp > | < new_comp > | < ext_comp >  | < new >  | < lookahead > |
----------------------------------------------------------------------
                                          |   < to be computed >     |
----------------------------------------------------------------------
                          |            < to be allocated >           |
----------------------------------------------------------------------

Source: vllm/v1/core/kv_cache_manager.py

comp is what the request has already computed; new_comp and ext_comp are tokens whose KV we get for free from the local prefix cache or an external store; new is this step’s actual work, and lookahead reserves slots for speculative draft tokens (Chapter 13). Every optimization later in the book is, in some sense, a different way of growing the free regions on the left so the “to be computed” region on the right shrinks. The allocator is the common substrate.
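The regions in the diagram reduce to token counts. Reading the rulers, “to be computed” covers new plus lookahead, while “to be allocated” begins at ext_comp. A plausible reading is that local prefix-cache hits already sit in blocks the request can share, whereas externally loaded KV needs fresh blocks to land in. The sketch below is a simplification under that reading: slot_regions is a hypothetical name, and it assumes the blocks covering comp and new_comp are already held.

```python
def cdiv(a: int, b: int) -> int:
    return -(-a // b)


def slot_regions(comp: int, new_comp: int, ext_comp: int, new: int,
                 lookahead: int, block_size: int) -> dict[str, int]:
    total = comp + new_comp + ext_comp + new + lookahead
    # Assumed already held: blocks for computed tokens and local cache hits.
    held_blocks = cdiv(comp + new_comp, block_size)
    return {
        "to_be_computed_tokens": new + lookahead,
        "to_be_allocated_tokens": ext_comp + new + lookahead,
        "blocks_to_allocate": cdiv(total, block_size) - held_blocks,
    }


# A fresh request: 32 tokens hit the local prefix cache, 16 come from an
# external KV store, 20 are real prefill work, 4 slots are reserved for drafts.
print(slot_regions(comp=0, new_comp=32, ext_comp=16, new=20, lookahead=4, block_size=16))
```

Under these assumptions the request computes only 24 tokens and needs only 3 fresh blocks, even though its logical length is 72. Growing the regions on the left is what shrinks both numbers.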

Coordinating layers and types

One subtlety the contiguous-buffer mental model hides: “the KV cache” is not one cache. It is one per attention layer, and modern models mix attention types, with full-attention layers, sliding-window layers, and even Mamba state interleaved in the same network. vLLM handles this with a KVCacheCoordinator over per-type SingleTypeKVCacheManager instances, each owning the block bookkeeping for one flavor of layer:

# vllm/v1/core/single_type_kv_cache_manager.py
class SingleTypeKVCacheManager(ABC):
    """
    An abstract base class for a manager that handle the kv cache management
    logic of one specific type of attention layer.
    """

Source: vllm/v1/core/single_type_kv_cache_manager.py

Layers that need identically-shaped slots are grouped, and the scheduler allocates at the LCM of the groups’ block sizes so a single block-table index lands consistently across every layer in a group. The payoff of this indirection is that the scheduler keeps thinking in terms of one homogeneous block budget, while the messy reality of heterogeneous layers stays quarantined behind the coordinator. The KV blocks returned to the scheduler are deliberately wrapped to “hide KVCacheManager’s internal data structure from the Scheduler” (the KVCacheBlocks docstring in kv_cache_manager.py says exactly this). Clean seam, and it is what makes the hybrid-model and MLA cases in Chapter 9 tractable.

The paged write: where logical blocks meet the kernel

So far this is all CPU-side bookkeeping. The blocks are integers in a table. The actual K and V tensors have to land in the right physical block on the GPU each step, and that write is the bridge between the allocator and the attention kernel. In the FlashAttention backend it is a single scatter:

# vllm/v1/attention/backends/flash_attn.py
reshape_and_cache_flash(
    key,
    value,
    key_cache,
    value_cache,
    slot_mapping,
    self.kv_cache_dtype,
    layer._k_scale,
    layer._v_scale,
)

Source: vllm/v1/attention/backends/flash_attn.py

The load-bearing argument is slot_mapping: a per-token tensor of flat destination indices. The block table maps a logical block to a physical block. slot_mapping goes one level finer and maps each individual new token to the exact flat offset in the pool where its K and V must land this step. The worker derives it from the request’s block table each step: look up which physical block the token belongs to, then add its offset inside that block. For a token in physical block $p$ at offset $o$, that is $\text{slot} = p \cdot B + o$ with $B$ = block_size (vllm/v1/worker/block_table.py computes precisely slot_ids = block_numbers * block_size + local_block_offsets). So when the kernel runs, it does not reason about blocks at all. It reads each new token’s K/V and writes it to the precomputed slot, anywhere in the pool, with no contiguity assumed. This is the moment paging stops being a metaphor and becomes a memory address. A comment immediately above this call in the source (not reproduced in the excerpt) notes that the op uses the shape of slot_mapping to find the real token count. The caller therefore does not even slice the padded K/V tensors first, a small concession to the CPU-overhead concerns from Chapter 3 that haunt this whole code path.
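The per-token slot computation can be sketched in plain Python for one request. It follows the formula the text attributes to block_table.py, but it is a scalar toy rather than the worker’s vectorized implementation, and slot_mapping_for is a hypothetical name. The block table is the one from the example earlier in the chapter (7, 2, 9).

```python
def slot_mapping_for(block_table: list[int], block_size: int,
                     start: int, num_new: int) -> list[int]:
    """Flat pool slots for tokens [start, start + num_new) of one request."""
    slots = []
    for pos in range(start, start + num_new):
        block_number = block_table[pos // block_size]   # physical block p
        local_offset = pos % block_size                 # offset o inside it
        slots.append(block_number * block_size + local_offset)
    return slots


table = [7, 2, 9]
print(slot_mapping_for(table, 16, start=14, num_new=4))  # [126, 127, 32, 33]
print(slot_mapping_for(table, 16, start=33, num_new=1))  # [145]
```

A chunk that straddles a block boundary produces non-adjacent slots (127, then 32). That discontinuity is exactly why the write goes through slot_mapping instead of assuming a contiguous destination.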

That is the write side, run once per layer per step to deposit the step’s new keys and values. The corresponding read side is what “PagedAttention” actually names: when computing attention, the kernel must gather K and V for all previous tokens, which are scattered across this sequence’s blocks all over the pool. It walks the block table to find each block and reads from it, instead of streaming one contiguous buffer. The diagram below traces both halves of one decode step: the scatter-write that deposits the new token’s KV, then the gather-read that pulls the whole history back for the attention math.

flowchart TD
    subgraph Write["Write side (once per layer per step)"]
        W1["new token K/V on GPU"] --> W2["slot_mapping: token to flat pool offset"]
        W2 --> W3["reshape_and_cache_flash scatters into assigned slot"]
    end
    W3 --> POOL["KV block pool (scattered across GPU)"]
    subgraph Read["Read side (PagedAttention)"]
        R1["block table for this sequence"] --> R2["gather K/V from every past block"]
        R2 --> R3["attention over full history"]
    end
    POOL --> R2

In current vLLM that gather is a feature of the FlashAttention varlen call rather than a standalone kernel, which is why this chapter sizes the cache and Chapter 9 does the kernel. The two are inseparable: paging the layout is useless without a kernel that can attend over the ragged result at the bandwidth roofline.
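The gather changes where K and V are read from, not the result. The toy below demonstrates this with single-query attention in pure Python: it reads the history through a block table and checks the answer against attention over a contiguous copy. Real kernels never materialize the gathered lists, so this illustrates only the addressing. attend and paged_attend are hypothetical names.

```python
import math


def attend(q, keys, values):
    """Scaled dot-product attention for a single query over the full history."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(dim)]


def paged_attend(q, k_pool, v_pool, block_table, block_size, seq_len):
    """Same math, but each past token's K/V is gathered from its paged slot."""
    keys, values = [], []
    for pos in range(seq_len):
        slot = block_table[pos // block_size] * block_size + pos % block_size
        keys.append(k_pool[slot])
        values.append(v_pool[slot])
    return attend(q, keys, values)


# Build a 10-token sequence, then scatter it into a 6-block pool via a block table.
block_size, block_table, seq_len = 4, [5, 0, 3], 10
k_seq = [[0.1 * i, 0.2 * (i % 3)] for i in range(seq_len)]
v_seq = [[float(i), float(-i)] for i in range(seq_len)]
k_pool = [[0.0, 0.0] for _ in range(6 * block_size)]
v_pool = [[0.0, 0.0] for _ in range(6 * block_size)]
for pos in range(seq_len):
    slot = block_table[pos // block_size] * block_size + pos % block_size
    k_pool[slot], v_pool[slot] = k_seq[pos], v_seq[pos]

q = [1.0, 0.5]
paged = paged_attend(q, k_pool, v_pool, block_table, block_size, seq_len)
contiguous = attend(q, k_seq, v_seq)
print(all(math.isclose(a, b) for a, b in zip(paged, contiguous)))  # True
```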

What we bought, and what we did not

Paging effectively eliminates both fragmentation classes. Internal waste drops to under one block per sequence; external waste goes to zero by construction. In practice this is the difference between a handful of concurrent sequences and many tens of them on the same card, and because decode is memory-bound, more concurrent sequences is very nearly more throughput. It is the single highest-leverage memory decision in the engine.

It is not free. The block table is an indirection on the hottest path in the system, and the attention kernel now chases pointers instead of streaming a contiguous region, which costs some bandwidth efficiency and a great deal of kernel-engineering effort to claw back. There is per-step CPU work to build slot_mapping and the block tables, exactly the overhead the FlashAttention comment warns about. And paging does not shrink the cache, it only stops you from wasting it. The total KV footprint per token is unchanged; you have made the budget usable, not smaller. Shrinking the bytes per token is a separate lever, pulled by KV-cache quantization in Chapter 12.

There is also a quieter assumption buried in everything above: that the KV cache grows linearly with sequence length, $O(n)$ for a sequence of $n$ tokens, and must be retained in full. That is true for vanilla attention, and it is what makes memory the bottleneck. But it is a property of the architecture, not a law.

Forward pointer: not every model pays this tax

Two architectures break the linear-growth assumption, and both reappear later.

Multi-head Latent Attention compresses the per-token KV into a small shared latent vector rather than storing full K and V per head. The spec is still paged through the same machinery, but the page-size formula is entirely different, and the allocator even tracks a separate storage_block_size to account for the compression:

# vllm/v1/kv_cache_interface.py  (MLAAttentionSpec)
@property
def storage_block_size(self) -> int:
    return self.block_size // self.compress_ratio

Source: vllm/v1/kv_cache_interface.py

A sequence still grows its KV, but each token costs a fraction of what full attention charges. This is how the flagship 2026 open models keep long-context serving affordable, and it is the heart of Chapter 9.
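A rough per-token comparison shows why the latent matters. The dimensions below are assumptions loosely modeled on DeepSeek-V3’s published config: 61 layers, a 512-wide latent plus a 64-wide rotary part, and a baseline of 128 heads of width 128. The sketch also assumes the latent is stored once per layer rather than as separate K and V. Check your model’s config before relying on the exact figures.

```python
def full_attention_bytes_per_token(layers, kv_heads, head_size, dtype_bytes):
    return 2 * layers * kv_heads * head_size * dtype_bytes  # separate K and V


def latent_bytes_per_token(layers, latent_width, dtype_bytes):
    return layers * latent_width * dtype_bytes  # one shared latent per layer


LAYERS, DTYPE = 61, 2  # assumed layer count; bf16
mla = latent_bytes_per_token(LAYERS, latent_width=512 + 64, dtype_bytes=DTYPE)
mha = full_attention_bytes_per_token(LAYERS, kv_heads=128, head_size=128, dtype_bytes=DTYPE)
print(f"latent: {mla / 1024:.1f} KiB/token, per-head K/V: {mha / 1024:.1f} KiB/token, "
      f"ratio {mha / mla:.0f}x")
```

Under these assumptions the latent costs about 69 KiB per token against roughly 3.8 MiB for per-head K/V. The cache still grows linearly, but with a much smaller constant.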

Recurrent and state-space models, Mamba among them, are the more radical case: their state is constant size, $O(1)$ in the sequence length rather than $O(n)$. The spec advertises this directly. Where FullAttentionSpec.max_memory_usage_bytes scaled with max_model_len, the Mamba spec’s default mode returns essentially one page, regardless of how long the sequence runs:

# vllm/v1/kv_cache_interface.py  (MambaSpec.max_memory_usage_bytes)
else:
    return self.page_size_bytes * (1 + self.num_speculative_blocks)

Source: vllm/v1/kv_cache_interface.py

For such layers there is no growing cache to page, no per-token tax, and the bottleneck analysis of this chapter simply does not apply, which is exactly why hybrid models need the per-type coordinator we met above. The block-as-the-budget framing is the right default for the transformers that dominate production today, but it is a default, not a universal.
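Putting the two max_memory_usage_bytes formulas side by side makes the difference in growth explicit. The page sizes are assumptions. The full-attention page is 65,536 bytes per 16-token block (one layer of the GQA shape: 2 x 16 x 8 x 128 x 2 bytes). The Mamba state page is an arbitrary 2 MiB, and num_speculative_blocks defaults to 0.

```python
def cdiv(a: int, b: int) -> int:
    return -(-a // b)


def full_attention_max_bytes(max_model_len: int, block_size: int, page_size_bytes: int) -> int:
    # Mirrors FullAttentionSpec: one page per block_size tokens, up to max_model_len.
    return cdiv(max_model_len, block_size) * page_size_bytes


def mamba_max_bytes(page_size_bytes: int, num_speculative_blocks: int = 0) -> int:
    # Mirrors the MambaSpec default branch: a constant, length-independent state.
    return page_size_bytes * (1 + num_speculative_blocks)


ATTN_PAGE = 2 * 16 * 8 * 128 * 2   # assumed: K+V, 16 tokens, 8 KV heads, width 128, bf16
MAMBA_PAGE = 2 * 2**20             # assumed: 2 MiB of recurrent state per layer
for max_len in (4096, 32768, 131072):
    attn = full_attention_max_bytes(max_len, 16, ATTN_PAGE)
    print(f"{max_len:6d} tokens: attention layer {attn / 2**20:6.0f} MiB, "
          f"mamba layer {mamba_max_bytes(MAMBA_PAGE) / 2**20:.0f} MiB")
```

Per sequence and per layer, the attention budget climbs from 16 MiB to 512 MiB as max_model_len grows 32x. The Mamba line does not move.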

With the cache sized and paged, the obvious next question is how the scheduler decides, every single step, which sequences get to grab those blocks. That is continuous batching, and it is where we turn next.

Further reading

  • Efficient Memory Management for Large Language Model Serving with PagedAttention — arXiv:2309.06180 — The originating paper; introduces paged KV blocks and a block table per sequence to kill internal and external fragmentation, and measures how badly contiguous allocation wasted memory. Read it for the fragmentation taxonomy and the OS-paging analogy that this whole chapter rests on.

Continuous batching and the token-budget scheduler

Chapter 4 left us with a paged KV cache: a pool of fixed-size blocks that lets a single replica hold many sequences in GPU memory at once without pre-reserving a contiguous buffer for each one’s worst case. That solved the memory problem. It said nothing about time — about which of those resident sequences actually gets to run on the next forward pass, and for how many tokens.

That is the scheduler’s job, and it is where the two-phase cost structure from Part I comes back to bite. Recall the asymmetry: prefill is compute-bound (one big matmul over the whole prompt) and decode is memory-bound (read the whole model and KV cache to emit one token per sequence). A naive serving loop treats these as two different operations on two different kinds of work, and the obvious design — the one almost everyone reaches for first — is static batching: collect a batch of requests, run all of them to completion, then collect the next batch. This chapter is about why that design is quietly catastrophic for LLM serving, and how vLLM replaces it with a scheduler that does not believe in prefill or decode as distinct phases at all.

Why static batching wastes a GPU

Hold the static-batching picture in your head for a moment. You batch eight requests. They have different prompt lengths and, worse, wildly different output lengths — one stops after 12 tokens, another runs for 800. Under static batching the whole batch is hostage to the longest member. Seven sequences finish and then sit in the batch as dead weight, their slots still occupied, their padding still consuming compute on every decode step, until the eighth finally emits its stop token. Meanwhile new requests pile up in a queue they are not allowed to enter, because the batch is “in flight.”

You already know this failure mode from traffic infrastructure. It is head-of-line blocking, and the batch is the convoy. While decode is memory-bound, the cost of a step is roughly flat in batch size. Adding more concurrent sequences to it is nearly free until you run out of KV cache or the step stops being memory-bound. The curve below shows why. A decode step’s cost is dominated by reading the model weights once, so packing more sequences into that one read raises aggregate token throughput almost linearly. That holds until the step reaches the roofline’s knee, where compute (or the KV reads that grow with every added sequence) starts to dominate and the curve bends into a flat plateau. Every slot wasted on a finished sequence is throughput from that linear region that you paid for and threw away.

Illustrative: the shape (near-linear rise, then a saturating knee at the bandwidth roofline) is the point, not the absolute token/s numbers, which depend on model and GPU.
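A back-of-the-envelope roofline model reproduces that shape. Every number here is an assumption: a 16 GB bf16 model with 8B parameters, 3 TB/s of HBM bandwidth, and 1 PFLOP/s of dense compute. The model also ignores KV-cache reads, which in practice grow with batch size and pull the knee earlier.

```python
WEIGHT_BYTES = 16e9     # assumed: 8B params in bf16
PARAMS = 8e9
HBM_BW = 3.0e12         # assumed bytes/s
PEAK_FLOPS = 1.0e15     # assumed FLOP/s


def decode_tokens_per_s(batch: int) -> float:
    # One step reads the weights once for the whole batch...
    memory_time = WEIGHT_BYTES / HBM_BW
    # ...but does ~2 FLOPs per parameter for every sequence in the batch.
    compute_time = 2 * PARAMS * batch / PEAK_FLOPS
    step_time = max(memory_time, compute_time)  # roofline: the slower side wins
    return batch / step_time


for b in (1, 8, 64, 256, 512, 1024):
    print(f"batch {b:5d}: {decode_tokens_per_s(b):9.0f} tokens/s")
```

Throughput scales linearly from 1 to about 333 sequences, where the compute time catches up with the fixed weight read. Above that point it flattens at 62,500 tokens/s under these assumptions.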

The diagram below contrasts the two regimes. Read each row as one sequence and the trailing dots as forward passes ticking by. Under static batching (top), once a sequence finishes its slot stays locked, and the queued request C cannot start until the whole batch drains. Under continuous batching (bottom), a finished slot is reclaimed on the very next pass and C slides into it immediately.

flowchart TB
    subgraph STATIC["static batching: batch runs to completion, then refills"]
        direction TB
        SA["A: decode . . . . . . . done"]
        SB["B: decode done (slot wasted, still padded)"]
        SC["C: WAITING in queue . . . . . . . . starts only now"]
    end
    subgraph CONT["continuous batching: membership re-decided every pass"]
        direction TB
        CA["A: decode . . . . . . . done"]
        CB["B: decode done, slot freed"]
        CC["C: waiting, then admitted into freed slot, then decode . . ."]
    end

The fix, introduced by Orca, is to stop scheduling at the granularity of a request and start scheduling at the granularity of an iteration — a single forward pass. An iteration (also called a step) is exactly one call to the model: the engine takes the current batch, runs one forward pass, and gets back one new token for every sequence in it. Scheduling “at iteration granularity” means the scheduler gets to redraw the batch before each of those passes, rather than once per request.
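The difference between the two regimes can be measured with a toy simulation that either refills the batch only when it drains or redraws it before every iteration. The simulation makes several simplifying assumptions: it counts forward passes only, ignores prefill and padding cost, and assumes every request generates at least one token. simulate is a hypothetical helper, not Orca’s or vLLM’s scheduler.

```python
from collections import deque


def simulate(output_lens: list[int], max_batch: int, continuous: bool) -> list[int]:
    """Return the forward-pass number at which each request finishes (toy model)."""
    queue = deque(enumerate(output_lens))
    running: dict[int, int] = {}   # request id -> tokens still to generate
    finish = [0] * len(output_lens)
    step = 0
    while queue or running:
        # Static batching refills only once the whole batch has drained;
        # iteration-level scheduling refills free slots before every pass.
        if continuous or not running:
            while queue and len(running) < max_batch:
                rid, remaining = queue.popleft()
                running[rid] = remaining
        step += 1                       # one forward pass: one token per sequence
        for rid in list(running):
            running[rid] -= 1
            if running[rid] == 0:
                finish[rid] = step
                del running[rid]        # slot is free for the next pass
    return finish


lens = [12, 800, 12, 12]  # requests A, B, C, D; two batch slots
print(simulate(lens, max_batch=2, continuous=False))  # [12, 800, 812, 812]
print(simulate(lens, max_batch=2, continuous=True))   # [12, 800, 24, 36]
```

The total number of passes barely changes. The short requests C and D, however, finish at passes 24 and 36 instead of 812, because they no longer wait behind B’s 800-token convoy.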

Iteration-level scheduling

The key idea of the Orca paper (Orca: A Distributed Serving System for Transformer-Based Generative Models, OSDI ’22) is what the paper calls iteration-level scheduling, and what the rest of us call continuous batching. Instead of admitting a batch and running it to completion, the scheduler re-decides the batch’s membership before every forward pass. A sequence that finished on the previous step is evicted from the batch immediately, and a waiting request can take its place on the very next iteration, with no waiting for the convoy to clear. Orca pairs this with “selective batching” to handle the fact that operations like attention cannot be naively batched across sequences of different lengths the way a feed-forward layer can. Read it as the origin point for everything in this chapter; it is the paper that reframed the batch from a static container into a per-step decision. (Orca is cited by title and venue; the description here is a paraphrase, not a quotation.)

vLLM takes this idea and pushes it one step further, to a place that is cleaner than Orca’s own framing. In vLLM there is no “selective batching” special case for attention versus everything else, and — the part that matters for this chapter — there is no prefill phase and no decode phase in the scheduler at all. The design note at the top of schedule() states it directly:

# NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
# Each request just has the num_computed_tokens and
# num_tokens_with_spec. num_tokens_with_spec =
# len(prompt_token_ids) + len(output_token_ids) + len(spec_token_ids).
# At each step, the scheduler tries to assign tokens to the requests
# so that each request's num_computed_tokens can catch up its
# num_tokens_with_spec.

(vllm/v1/core/sched/scheduler.py)

Sit with what this collapses. A request is not “in prefill” or “in decode.” It is a pair of counters:

  • num_computed_tokens — how many of this request’s tokens the model has already pushed through a forward pass and recorded in the KV cache. It starts at 0 and only grows.
  • num_tokens_with_spec — how many tokens exist for this request right now: the prompt, plus every output token generated so far, plus any speculative tokens being proposed this step. The design note spells it out as len(prompt_token_ids) + len(output_token_ids) + len(spec_token_ids).

The gap between them, num_tokens_with_spec - num_computed_tokens, is the work still owed: tokens that exist but have not yet been run through the model. A fresh request with a 2,000-token prompt and nothing computed has a gap of 2,000; that is what we used to call “prefill.” A request mid-generation has already computed its entire prompt and every output except the token it just sampled. Its gap is that single new token, which must go through the model to produce the next one: a gap of 1 (or a few, when speculation proposes several at once, as in Chapter 13). That is what we used to call “decode.” The scheduler’s only job, every step, is to hand out token allotments that shrink these gaps. Prefill is just a large allotment; decode is a small one. They are the same operation at different scales, and the only thing distinguishing them is the size of one subtraction.
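The two counters are small enough to model directly. ToyRequest is a hypothetical class whose field names mirror the design note. It ignores everything else a real request carries, including the num_output_placeholders term that appears in the scheduler code later in this chapter.

```python
from dataclasses import dataclass, field


@dataclass
class ToyRequest:
    """Hypothetical request carrying only the counters from the design note."""
    prompt_token_ids: list[int]
    output_token_ids: list[int] = field(default_factory=list)
    spec_token_ids: list[int] = field(default_factory=list)
    num_computed_tokens: int = 0

    @property
    def num_tokens_with_spec(self) -> int:
        return (len(self.prompt_token_ids) + len(self.output_token_ids)
                + len(self.spec_token_ids))

    @property
    def gap(self) -> int:
        # Work still owed: tokens that exist but have not been through the model.
        return self.num_tokens_with_spec - self.num_computed_tokens


req = ToyRequest(prompt_token_ids=list(range(2000)))
print(req.gap)  # 2000: what we used to call prefill

req.num_computed_tokens = 2000   # the whole prompt ran through the model...
req.output_token_ids.append(42)  # ...and that pass sampled one token
print(req.gap)  # 1: what we used to call decode

req.spec_token_ids = [7, 8, 9]   # a drafter proposes three tokens (Chapter 13)
print(req.gap)  # 4: the sampled token plus three drafts
```

Nothing in the class knows which “phase” it is in. The same subtraction yields 2,000, then 1, then 4.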

The state diagram below traces a single request through these counters. Notice there is no “prefill state” and no “decode state” — only WAITING, RUNNING, and the gap shrinking toward zero, with the same loop handling the 2,000-token first step and every 1-token step after it.

stateDiagram-v2
    [*] --> WAITING: request arrives, computed = 0
    WAITING --> RUNNING: blocks allocated, admitted to batch
    RUNNING --> RUNNING: step, computed += allotment, gap shrinks
    RUNNING --> WAITING: preempted, no free KV blocks
    RUNNING --> FINISHED: stop token, gap closed for good
    FINISHED --> [*]: KV blocks returned to pool

This is the thesis of the chapter, and it is worth saying plainly: once you frame scheduling this way, the scheduler becomes a token-level rate limiter, and you already know how to reason about those.

The token budget

The rate limit is a single integer. At the top of every schedule() call:

token_budget = self.max_num_scheduled_tokens

(vllm/v1/core/sched/scheduler.py)

That budget is the maximum number of tokens the engine will process in one forward pass, and it is set from configuration:

self.max_num_scheduled_tokens = (
    self.scheduler_config.max_num_scheduled_tokens
    if self.scheduler_config.max_num_scheduled_tokens is not None
    else self.scheduler_config.max_num_batched_tokens
)

(vllm/v1/core/sched/scheduler.py)

So --max-num-batched-tokens is the knob, and max_num_scheduled_tokens is usually equal to it (it can be slightly smaller when speculative decoding might append tokens, per the config docstring). The default is modest — DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048 in vllm/config/scheduler.py, though in real deployments it is set far higher by EngineArgs. There is a second, orthogonal limit: max_num_seqs (default 128), the maximum number of distinct sequences in a batch. The token budget caps total work per step; max_num_seqs caps the width of the batch. Decode steps usually hit the sequence cap first (one token each, lots of sequences); a single long prefill can blow the whole token budget on its own.

This is the rate-limiter framing made literal. The budget is a bucket of tokens that refills to the same full value at the start of every step. The scheduler walks its queues and spends from the bucket — one token of work charges one unit — and the moment the bucket hits zero, scheduling stops for that step: every loop in schedule() is gated on token_budget > 0. Whatever did not fit simply waits for the next refill, one forward pass later. What makes it a good rate limiter for this workload is that a token is a token: a prefill token and a decode token cost the scheduler the same one unit of budget, which is exactly the uniformity the design note promised. The bucket does not know or care whether the tokens it is paying out belong to one giant prompt or to fifty sequences each advancing by one — it only counts to its limit and stops.
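The bucket fits in one function. The sketch below makes several assumptions: plan_step is a hypothetical name, requests arrive as (id, tokens owed) pairs in queue order, and KV allocation and preemption are left out entirely. The driver loop keeps 50 decoding sequences ahead of one 5,000-token prompt, with each decode owing one token every step.

```python
def plan_step(owed: list[tuple[str, int]], token_budget: int,
              max_num_seqs: int) -> dict[str, int]:
    """Spend one step's token bucket across requests in queue order (toy)."""
    plan = {}
    for rid, gap in owed:
        if token_budget <= 0 or len(plan) >= max_num_seqs:
            break
        allotment = min(gap, token_budget)  # a long prompt gets chunked here
        plan[rid] = allotment
        token_budget -= allotment
    return plan


# 50 decoding sequences (1 token each) ahead of one 5,000-token prompt.
# The bucket refills to 2048 every step; max_num_seqs must exceed 50 here,
# or the prompt would never get a seat.
decodes = [(f"d{i}", 1) for i in range(50)]
prompt_left, steps = 5000, 0
while prompt_left > 0:
    plan = plan_step(decodes + [("p", prompt_left)], token_budget=2048, max_num_seqs=128)
    prompt_left -= plan.get("p", 0)
    steps += 1
    print(f"step {steps}: prompt chunk {plan.get('p', 0)}, total tokens {sum(plan.values())}")
```

The prompt is consumed in chunks of 1998, 1998, and 1004 tokens over three steps, while every decode advances by one token per step. Nothing in the function distinguishes the two kinds of work; the chunking falls out of min(gap, token_budget).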

Two passes: running first, then waiting

With the budget in hand, the structure of schedule() is two sequential passes over two queues. The running queue holds requests already in the batch and generating; the waiting queue holds new arrivals and previously-preempted requests that want in. The order is deliberate and is the central policy choice of the whole scheduler: spend the budget on the running requests first, and only offer the leftovers to waiting ones. The flowchart below traces one full schedule() call. Every diamond that tests the budget is the same token_budget > 0 guard; when the bucket empties, the loop stops wherever it is and the rest of the queue waits for the next step.

flowchart TD
    START["schedule begins: token_budget = max_num_scheduled_tokens"] --> P1Q{"running queue empty or budget == 0?"}
    P1Q -->|"no"| P1G["gap = num_tokens_with_spec - num_computed_tokens"]
    P1G --> P1C["allotment = min of gap and token_budget"]
    P1C --> P1A{"allocate_slots: KV blocks available?"}
    P1A -->|"yes"| P1S["add to batch, token_budget -= allotment"]
    P1S --> P1Q
    P1A -->|"no"| P1P["preempt a running request (Chapter 6), mark step preempted"]
    P1P --> P1Q
    P1Q -->|"yes, Pass 1 done"| P2Q{"any preemption this step, or budget == 0?"}
    P2Q -->|"yes"| DONE["emit SchedulerOutput plan"]
    P2Q -->|"no"| P2W{"waiting queue empty or batch at max_num_seqs?"}
    P2W -->|"no"| P2G["check prefix cache, then same gap-clamp-allocate"]
    P2G --> P2S["move request from WAITING to RUNNING, token_budget -= allotment"]
    P2S --> P2W
    P2W -->|"yes"| DONE

Pass 1 covers the requests already running:

# First, schedule the RUNNING requests.
req_index = 0
while req_index < len(self.running) and token_budget > 0:
    request = self.running[req_index]
    ...
    num_new_tokens = (
        request.num_tokens_with_spec
        + request.num_output_placeholders
        - request.num_computed_tokens
    )

(vllm/v1/core/sched/scheduler.py)

There is the gap computation, straight out of the design note: num_new_tokens is how far this request still has to go. For a request that is mid-generation, that gap is 1: it needs one decode token. For a request still chunking through a long prompt, the gap is whatever is left of the prompt. The scheduler does not care which. It asks to schedule the full gap, then clamps that figure down to whatever budget remains:

num_new_tokens = min(num_new_tokens, token_budget)

(vllm/v1/core/sched/scheduler.py)

That one min is the seam where prefill and decode reunite, and it is the hinge for the next chapter: a long prefill that wants 2,000 tokens but finds only 256 left in the budget gets chunked to 256 and rides alongside the decodes that already claimed the rest. We will develop chunked prefill properly in Chapter 6; for now, notice that it falls out of the rate limiter for free, with no special case.
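The gap-then-clamp arithmetic can be sketched as a pure function. ToyRequest and tokens_to_schedule are hypothetical names. The field names copy the counters quoted above, and num_output_placeholders is simply assumed to be 0.

```python
from dataclasses import dataclass


@dataclass
class ToyRequest:
    # Field names mirror the counters quoted above; the class itself is a toy.
    num_tokens_with_spec: int         # prompt + generated (+ draft) tokens
    num_computed_tokens: int          # tokens already written to the KV cache
    num_output_placeholders: int = 0  # assumed 0 in this sketch


def tokens_to_schedule(req: ToyRequest, token_budget: int) -> int:
    """Gap computation, then the clamp to the remaining budget."""
    gap = (
        req.num_tokens_with_spec
        + req.num_output_placeholders
        - req.num_computed_tokens
    )
    return min(gap, token_budget)


decoding = ToyRequest(num_tokens_with_spec=513, num_computed_tokens=512)
prefilling = ToyRequest(num_tokens_with_spec=2000, num_computed_tokens=0)
print(tokens_to_schedule(decoding, 256))    # 1: one decode token
print(tokens_to_schedule(prefilling, 256))  # 256: chunked to what is left
```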

Once a token count is settled, the request needs KV blocks for those tokens, and this is where Chapter 4’s allocator re-enters:

new_blocks = self.kv_cache_manager.allocate_slots(
    request,
    num_new_tokens,
    num_lookahead_tokens=self.num_lookahead_tokens,
)

if new_blocks is not None:
    # The request can be scheduled.
    break

(vllm/v1/core/sched/scheduler.py)

If allocate_slots returns None, the paged pool is out of blocks, and the scheduler must make room by preempting a running request — the load-shedding valve we hand off to Chapter 6. For now the happy path: blocks granted, the request joins the batch, and the budget is debited.

scheduled_running_reqs.append(request)
request_id = request.request_id
req_to_new_blocks[request_id] = new_blocks
num_scheduled_tokens[request_id] = num_new_tokens
token_budget -= num_new_tokens

(vllm/v1/core/sched/scheduler.py)

token_budget -= num_new_tokens is the rate limiter spending from its bucket. Running requests are served first, every step, which is the policy choice that protects inter-token latency: a request that is already generating keeps its decode cadence rather than being starved by a flood of new arrivals.

Pass 2 handles the waiting queue — new arrivals and previously-preempted requests — but only with whatever budget Pass 1 left behind:

# Next, schedule the WAITING requests.
if not preempted_reqs and self._pause_state == PauseState.UNPAUSED:
    ...
    while (self.waiting or self.skipped_waiting) and token_budget > 0:
        if len(self.running) == self.max_num_running_reqs:
            break

(vllm/v1/core/sched/scheduler.py)

Two guards worth noting. First, if Pass 1 had to preempt anyone (preempted_reqs is non-empty), Pass 2 is skipped entirely — there is no point admitting new work in a step where the engine is already shedding load. Second, the max_num_running_reqs ceiling (that’s max_num_seqs) caps batch width here, independently of the token budget. A waiting request goes through the same gap-and-clamp logic, checks the prefix cache for already-computed blocks (Chapter 7), allocates slots, and on success moves to RUNNING.
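The two passes and their guards can be condensed into a toy function. It assumes requests are (req_id, gap) pairs and that a hypothetical can_allocate callback stands in for allocate_slots. To stay short, it sheds the request whose allocation failed instead of running the real victim-selection loop. Treat it as an outline of the control flow, not of vLLM's preemption policy.

```python
from collections import deque


def schedule(running, waiting, max_num_scheduled_tokens, max_num_seqs, can_allocate):
    """Toy two-pass schedule(); an outline, not vLLM's code.

    running: list of (req_id, gap) already in the batch, in admission order.
    waiting: deque of (req_id, gap) wanting in.
    can_allocate(req_id, n): hypothetical stand-in for allocate_slots.
    Returns (plan, ids of requests shed this step).
    """
    token_budget = max_num_scheduled_tokens
    plan, shed, batch = {}, [], []
    # Pass 1: running requests spend the budget first.
    for req_id, gap in running:
        if token_budget <= 0:
            batch.append((req_id, gap))  # still running, just idle this step
            continue
        n = min(gap, token_budget)
        if not can_allocate(req_id, n):
            # Simplification: shed the request that failed instead of
            # running the victim-selection loop.
            shed.append((req_id, gap))
            waiting.appendleft((req_id, gap))  # back to the front of the queue
            continue
        plan[req_id] = n
        token_budget -= n
        batch.append((req_id, gap))
    # Pass 2: skipped entirely in a step that shed load.
    if not shed:
        while waiting and token_budget > 0 and len(batch) < max_num_seqs:
            req_id, gap = waiting[0]
            n = min(gap, token_budget)
            if not can_allocate(req_id, n):
                break  # leave it at the head for a later step
            waiting.popleft()  # admitted: WAITING -> RUNNING
            plan[req_id] = n
            token_budget -= n
            batch.append((req_id, gap))
    return plan, [req_id for req_id, _ in shed]


def always(req_id, n):
    return True  # pretend the KV pool never runs out


waiting = deque([("c", 3000), ("d", 10)])
plan, shed = schedule([("a", 1), ("b", 1)], waiting, 2048, 3, always)
print(plan)  # {'a': 1, 'b': 1, 'c': 2046}
```

Request c is admitted with a 2,046-token first chunk. Request d waits because the bucket is empty. If can_allocate refused b, the same call would shed b to the front of the queue and skip Pass 2 altogether.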

The order of the queues

Which waiting request is at the front? That is the entire content of request_queue.py, and it is deliberately boring — which is the point. The two policies are an enum:

class SchedulingPolicy(Enum):
    """Enum for scheduling policies."""

    FCFS = "fcfs"
    PRIORITY = "priority"

(vllm/v1/core/sched/request_queue.py)

FCFS is literally a deque; pop_request is popleft, add_request is append. Priority is a binary heap ordered by (priority, arrival_time). The default is "fcfs" (vllm/config/scheduler.py). The reason this is so thin is that all the interesting decisions — how many tokens, fit-or-preempt, prefix reuse — live in schedule() against the token budget. The queue only answers “who’s next,” and for a system whose hard constraints are tokens and blocks, plain arrival order is a defensible default. This is the part of the design that should feel most familiar to a traffic engineer: it is admission control with a token bucket sitting in front of an FCFS queue, and the cleverness is in the bucket, not the queue.

Note the preemption asymmetry baked into the queue API. When a request is preempted (Chapter 6), _preempt_request does self.waiting.prepend_request(request) — it goes back to the front under FCFS, so a preempted request resumes ahead of newer arrivals. Under priority there is no real front; prepend_request just re-inserts by (priority, arrival_time). Same call site, different semantics, chosen by the policy object.
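Here is a minimal sketch of the two queue shapes. It uses the method names mentioned above (add_request, pop_request, prepend_request) on hypothetical toy classes, not vLLM's own. It shows the asymmetry directly: the same prepend_request call moves a low-priority request to the head under FCFS, but re-files it by key under priority.

```python
import heapq
from collections import deque
from dataclasses import dataclass


@dataclass(eq=False)
class Req:
    request_id: str
    priority: int        # lower value is served first
    arrival_time: float


class FCFSQueue:
    """Toy FCFS queue: a deque, as described above."""

    def __init__(self):
        self._q = deque()

    def add_request(self, req):
        self._q.append(req)

    def pop_request(self):
        return self._q.popleft()

    def prepend_request(self, req):
        self._q.appendleft(req)  # preempted request resumes ahead of newer ones

    def __len__(self):
        return len(self._q)


class PriorityQueue:
    """Toy priority queue: a binary heap keyed by (priority, arrival_time)."""

    def __init__(self):
        self._heap = []

    def add_request(self, req):
        # request_id breaks exact ties so Req objects are never compared.
        key = (req.priority, req.arrival_time, req.request_id)
        heapq.heappush(self._heap, (*key, req))

    def pop_request(self):
        return heapq.heappop(self._heap)[-1]

    def prepend_request(self, req):
        self.add_request(req)  # no real "front": re-insert by key

    def __len__(self):
        return len(self._heap)


r1, r2, r3 = Req("r1", 0, 1.0), Req("r2", 0, 2.0), Req("r3", 1, 0.5)

fcfs = FCFSQueue()
fcfs.add_request(r1)
fcfs.add_request(r2)
fcfs.prepend_request(r3)  # r3 was preempted: jumps to the head
fcfs_order = [fcfs.pop_request().request_id for _ in range(len(fcfs))]

prio = PriorityQueue()
prio.add_request(r1)
prio.add_request(r2)
prio.prepend_request(r3)  # same call, but r3's key (1, 0.5) sorts last
prio_order = [prio.pop_request().request_id for _ in range(len(prio))]

print(fcfs_order, prio_order)  # ['r3', 'r1', 'r2'] ['r1', 'r2', 'r3']
```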

The output contract and closing the loop

The scheduler does not run the model. It produces a plan and hands it to the worker. That plan is SchedulerOutput, and its shape is the scheduler-to-worker contract:

@dataclass
class SchedulerOutput:
    # list of the requests that are scheduled for the first time.
    scheduled_new_reqs: list[NewRequestData]
    # list of the requests that have been scheduled before.
    # ... we only send the diff to minimize the communication cost.
    scheduled_cached_reqs: CachedRequestData

    # req_id -> num_scheduled_tokens
    num_scheduled_tokens: dict[str, int]
    # Total number of tokens scheduled for all requests.
    total_num_scheduled_tokens: int

(vllm/v1/core/sched/output.py)

The num_scheduled_tokens dict is the per-step plan: for each request in the batch, how many of its tokens to process this pass. Note the new/cached split — a brand-new request ships its full payload once (NewRequestData), and on subsequent steps only the diff (CachedRequestData) crosses the process boundary, because the worker caches request state. That diff-minimization matters because, as Chapter 11 will show, the scheduler and the worker live in separate processes connected by ZMQ, and this object is serialized across that seam every single step.
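Here is a hypothetical sketch of the diffing idea, showing why the new/cached split saves bytes. ToyPlan and ToyPlanBuilder are invented for illustration. The only diff they carry is newly appended token ids. vLLM's actual CachedRequestData has its own set of fields, which this sketch does not try to reproduce.

```python
from dataclasses import dataclass


@dataclass
class ToyPlan:
    new_reqs: dict                  # req_id -> full token ids, sent once
    cached_reqs: dict               # req_id -> only ids appended since last send
    num_scheduled_tokens: dict      # req_id -> tokens to process this step
    total_num_scheduled_tokens: int


class ToyPlanBuilder:
    """Hypothetical sketch of new/cached diffing; not vLLM's classes."""

    def __init__(self):
        self._sent_len = {}  # req_id -> number of token ids the worker holds

    def build(self, step):
        """step: dict of req_id -> (all_token_ids, num_to_schedule)."""
        new_reqs, cached_reqs, counts = {}, {}, {}
        for req_id, (token_ids, n) in step.items():
            if req_id not in self._sent_len:
                new_reqs[req_id] = list(token_ids)  # first time: full payload
            else:
                already = self._sent_len[req_id]
                cached_reqs[req_id] = list(token_ids[already:])  # diff only
            self._sent_len[req_id] = len(token_ids)
            counts[req_id] = n
        return ToyPlan(new_reqs, cached_reqs, counts, sum(counts.values()))


builder = ToyPlanBuilder()
step1 = builder.build({"a": ([11, 12, 13, 14], 4)})
step2 = builder.build({"a": ([11, 12, 13, 14, 99], 1), "b": ([7, 8], 2)})
print(step2.new_reqs, step2.cached_reqs, step2.total_num_scheduled_tokens)
# {'b': [7, 8]} {'a': [99]} 3
```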

The loop that drives it all is short enough to quote nearly whole, in EngineCore.step():

scheduler_output = self.scheduler.schedule()
future = self.model_executor.execute_model(scheduler_output, non_block=True)
grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
...
    model_output = future.result()
...
engine_core_outputs = self.scheduler.update_from_output(
    scheduler_output, model_output
)

(vllm/v1/engine/core.py)

Schedule, execute, update — repeat. That three-beat loop is the engine’s heartbeat, so it is worth tracing the data crossing the process boundary on a single beat. The sequence diagram below shows one step: the scheduler emits a plan, the worker runs the model and samples, and the result flows back so the scheduler can re-plan. The arrows that cross from worker to scheduler are exactly what update_from_output consumes.

sequenceDiagram
    participant S as Scheduler
    participant W as Worker (model executor)
    S->>S: schedule, spend token budget, advance num_computed_tokens
    S->>W: SchedulerOutput, num_scheduled_tokens per request
    W->>W: execute_model, one forward pass then sample
    W->>S: sampled tokens (model_output)
    S->>S: update_from_output, append token, check stop, roll back rejected drafts
    S->>S: next step, schedule again

The third beat, update_from_output, is what makes the batching continuous. It takes the sampled tokens back from the worker and reconciles each request’s state. For a normal decode it appends the new token and checks for a stop condition. For speculative decoding it also rolls num_computed_tokens back over rejected drafts:

if request.num_computed_tokens > 0:
    request.num_computed_tokens -= num_rejected

(vllm/v1/core/sched/scheduler.py)

That single subtraction is the same uniform machinery again: a rejected speculative token simply means the request did not advance as far as the plan assumed, so the gap it must catch up next step grows back by num_rejected. No phase, no special path — just an adjustment to the counter the design note built everything around.

To see why a subtraction is the natural correction here, look at the ordering. The accounting groundwork is laid before the worker even runs, in _update_after_schedule, which advances num_computed_tokens immediately at schedule time:

request.num_computed_tokens += num_scheduled_token

(vllm/v1/core/sched/scheduler.py)

The point of advancing the counter optimistically, at schedule time rather than after results come back, is latency: the engine pipelines steps, so the next schedule() can plan against the assumption that this step succeeded without blocking on its output. For an ordinary decode that assumption always holds — the token gets generated. For speculative decoding it might not, because some proposed draft tokens get rejected. So the correction is to undo the part of the optimistic advance that turned out to be wrong, which is exactly what the -= num_rejected rollback above does. Advance optimistically, walk it back only on rejection.
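The advance-then-rollback bookkeeping can be sketched with two counters. The function names are hypothetical. The sketch assumes that verifying d draft tokens yields the accepted drafts plus one token sampled by the target model. Under that assumption, the number rejected is d + 1 minus the number generated.

```python
from dataclasses import dataclass


@dataclass
class ToyRequest:
    num_tokens: int           # tokens known so far (prompt + sampled)
    num_computed_tokens: int  # tokens already in the KV cache


def schedule_step(req, num_drafts):
    """Schedule the gap (including draft tokens), advancing optimistically."""
    num_scheduled = req.num_tokens + num_drafts - req.num_computed_tokens
    req.num_computed_tokens += num_scheduled  # assume the whole step succeeds
    return num_scheduled


def apply_output(req, num_drafts, num_generated):
    """num_generated = accepted drafts + the one token the target model samples."""
    req.num_tokens += num_generated
    num_rejected = num_drafts + 1 - num_generated
    if req.num_computed_tokens > 0:
        req.num_computed_tokens -= num_rejected  # walk back rejected drafts only


req = ToyRequest(num_tokens=100, num_computed_tokens=99)
schedule_step(req, num_drafts=4)                  # 5 scheduled; computed jumps to 104
apply_output(req, num_drafts=4, num_generated=3)  # 2 drafts accepted + 1 sampled
print(req.num_tokens, req.num_computed_tokens)    # 103 102
```

After the rollback, num_computed_tokens sits exactly one behind num_tokens again. An ordinary decode maintains the same invariant, which is why the next step's gap computation needs no special case.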

When a request stops, it is dropped from running, its request_id lands in finished_req_ids so the worker frees its cached state, and its KV blocks return to the pool — instantly available to the waiting request that the very next schedule() will admit into the freed slot. That is continuous batching closing the loop: no convoy, no head-of-line block, the batch reshaped to fit reality on every pass.

Tradeoffs and what is still unsolved

The token-budget design buys uniformity, but it does not buy you out of the throughput-versus-latency tension from Chapter 2 — it just relocates it onto one knob. A large max_num_batched_tokens lets big prefills run in fewer chunks and raises throughput, but a step dominated by a 4,000-token prefill stalls every concurrent decode for the duration of that step, spiking inter-token latency for everyone else. A small budget keeps decodes responsive but chunks prefills into more, smaller pieces, adding per-step overhead and slowing time-to-first-token. The two curves below show the bind: as the budget grows, throughput climbs and saturates, but the worst-case inter-token latency (a decode trapped behind a maxed-out step) climbs roughly in step with it. The scheduler gives you a clean rate limiter; it does not tell you where to set the rate. Chapter 6 is where that knob earns its keep, once chunked prefill and preemption give you finer control over how a step’s budget is spent.

Illustrative: throughput saturating while worst-case inter-token latency grows roughly linearly with the budget is the relationship to take away, not the specific values.
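A back-of-envelope model reproduces the shape of those two curves. It assumes, hypothetically, that a maxed-out step costs a fixed overhead plus a constant per-token cost. The constants are invented, and the fixed term is only a crude stand-in for the work every step pays regardless of size.

```python
def full_step(budget, overhead_ms=5.0, per_token_ms=0.01):
    """Hypothetical linear cost model for one maxed-out step.

    Returns (worst_case_itl_ms, throughput_tokens_per_s). The constants are
    invented for illustration, not measured.
    """
    # A decode trapped behind a full step waits for the whole step.
    step_ms = overhead_ms + per_token_ms * budget
    # Tokens per second when every step is full.
    throughput = budget / step_ms * 1000.0
    return step_ms, throughput


for budget in (512, 2048, 8192, 32768):
    itl, tput = full_step(budget)
    print(f"budget={budget:>6}  worst-case ITL ~{itl:7.2f} ms  throughput ~{tput:8.0f} tok/s")
```

Throughput approaches 1000 / per_token_ms tokens per second but never reaches it. Meanwhile, the worst-case inter-token latency (the duration of one full step) keeps growing linearly with the budget.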

Two honest limitations to carry forward. First, FCFS is starvation-resistant but not fairness-aware: short cheap requests and long expensive ones share one budget with no notion of per-tenant shares — multi-tenancy is a fleet-level concern (Part V), not something this single-replica scheduler addresses. Second, the scheduler plans blind to output length: it knows a prompt’s size exactly but cannot know whether it will generate 5 tokens or 5,000, so it cannot anticipate when KV pressure is about to force a preemption. It reacts when allocate_slots returns None rather than predicting it. Learned output-length estimation to drive smarter admission and preemption remains genuinely open.

For now we have the engine’s heartbeat: a per-step, token-budgeted re-planning loop in which prefill and decode are the same operation seen at different scales. The next chapter spends that budget deliberately — slicing long prefills into chunks that ride alongside decodes, and turning preemption-with-recompute into the engine’s backpressure valve.

Further reading

  • Orca: A Distributed Serving System for Transformer-Based Generative Models — OSDI ’22 — introduces iteration-level (continuous) batching and selective batching; the conceptual origin of this chapter’s scheduler.
  • Taming Throughput-Latency Tradeoff in LLM Inference with Sarathi-Serve — arXiv:2403.02310 — formalizes the throughput/latency tension the token budget exposes and motivates the chunked-prefill mechanism developed in Chapter 6.

Chunked prefill, token budgets, and preemption as load-shedding

The previous chapter left us with a scheduler that re-plans the batch every step against a single number: a token budget. There is no prefill phase and no decode phase, only requests whose num_computed_tokens is chasing their num_tokens_with_spec, and a budget that says how many tokens the whole batch may advance this step. That reframing was the hard conceptual move. This chapter cashes it in.

Two problems remain unaddressed, and both are exactly the kind a traffic engineer is primed to see. The first is interference. A token budget tells you how much total work a step may contain, but it says nothing about how that work is distributed across requests. If one request shows up with a 30,000-token prompt and you let its prefill consume an entire step, every other request in the batch gets zero decode tokens that step. Their inter-token latency spikes. You have a head-of-line blocking problem dressed up as a prefill. The second problem is admission under pressure: the budget bounds compute per step, but the KV cache is a separate, finite resource, and nothing yet stops the scheduler from admitting more sequences than their growing caches will fit. When that happens, something has to give.

Chunked prefill answers the first. Preemption-with-recompute answers the second. They turn out to be the same machinery viewed from two angles, and both fall directly out of the token-budget loop you already understand.

Before diving in, a quick recap of the vocabulary this chapter leans on, all carried over from Chapter 5. A request’s num_computed_tokens is how many of its tokens the engine has already pushed through the model and written into the KV cache; num_tokens is how many it has in total (prompt plus tokens generated so far). A request is prefilling while num_computed_tokens is still climbing toward the end of its prompt, and decoding once it has caught up and is producing one new token per step. The token budget (max_num_batched_tokens) is the cap on how many tokens the whole batch may advance in a single schedule() step. The KV cache is a fixed pool of fixed-size blocks, each holding the keys and values for a handful of token positions; every token the engine computes must be written into a block, so the pool is a hard ceiling on how many tokens can be live at once. Keep those four ideas in hand and the rest of the chapter is bookkeeping on top of them.

Slicing the long prefill

Start with the interference problem, because its fix is almost embarrassingly small once the budget exists. A prefill of length $L$ is just $L$ tokens of work that all happen to belong to one request. The scheduler does not actually care that they belong to one request; it cares about the budget. So instead of scheduling all $L$ at once, it schedules $\min(L, \text{remaining budget})$ of them, advances num_computed_tokens by that much, and comes back for the rest on the next step. The prompt gets prefilled in chunks — a chunk being just one budget-sized slice of a prefill — and crucially those slices share each step with the decodes of every other running request.

The diagram below contrasts the two ways a 30,000-token prefill can hit the schedule. Without chunking, the long prefill monopolizes a step and the other requests’ decodes wait; with chunking, the same prefill is dribbled in a slice at a time so each step still has room for everyone’s decode token, keeping inter-token latency (ITL) smooth.

flowchart TD
    subgraph NoChunk["Without chunking (one step)"]
        A["30000-token prefill fills the whole step"] --> B["every other request's decode waits"]
        B --> C["inter-token latency spikes for the batch"]
    end
    subgraph Chunk["With chunking (many steps)"]
        D["step 1: prefill slice (budget) plus decodes"] --> E["step 2: next prefill slice plus decodes"]
        E --> F["step 3: next prefill slice plus decodes"]
        F --> G["prefill finishes, decodes never stalled"]
    end

Here is the clamp, in the WAITING-request path of the scheduler:

# vllm/v1/core/sched/scheduler.py
num_new_tokens = request.num_tokens - num_computed_tokens
threshold = self.scheduler_config.long_prefill_token_threshold
if 0 < threshold < num_new_tokens:
    num_new_tokens = threshold
...
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0

Source: vllm/v1/core/sched/scheduler.py

That min(num_new_tokens, token_budget) is the whole idea. Walk it left to right. num_new_tokens starts as everything the request still owes, request.num_tokens - num_computed_tokens: the gap between how long the request is and how much of it has been computed. The long_prefill_token_threshold clamp (we will come back to it) optionally tightens that. Then min(..., token_budget) caps it at whatever budget is left for this step. The result is the number of tokens this request gets to advance now. Whatever is left over simply stays uncomputed. The request sits in RUNNING, and the next schedule() call picks it up exactly where it left off. This works because _update_after_schedule advances num_computed_tokens by exactly the number of tokens scheduled (the same optimistic advance from the previous chapter). So a long prefill makes progress one budget-sized bite per step, automatically, with no state machine tracking “which chunk are we on”; the chunk position is num_computed_tokens. There is no separate “chunked prefill” code path in any meaningful sense; chunking is just what the budget cap does to a prefill that is larger than the budget.
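The counter-as-chunk-position idea can be sketched in a few lines. prefill_steps is a hypothetical helper. It assumes a fixed number of one-token decodes take their share of every step before the long prompt gets its chunk.

```python
def prefill_steps(prompt_len, token_budget, num_decodes):
    """Count steps to prefill one prompt when num_decodes one-token decodes
    take their share of every step first (an assumption of this sketch)."""
    per_step = token_budget - num_decodes  # room left for this prompt's chunk
    if per_step <= 0:
        raise ValueError("decodes consume the whole budget")
    num_computed_tokens = 0  # the chunk position is just this counter
    steps = 0
    while num_computed_tokens < prompt_len:
        chunk = min(prompt_len - num_computed_tokens, per_step)
        num_computed_tokens += chunk  # advanced by exactly what was scheduled
        steps += 1
    return steps


print(prefill_steps(30_000, token_budget=2048, num_decodes=48))   # 15
print(prefill_steps(30_000, token_budget=8192, num_decodes=192))  # 4
```

Quadrupling the budget cuts the 30,000-token prefill from 15 steps to 4. The cost is that each of those 4 steps is longer for every decode riding in it.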

The one thing a flag controls is whether to allow it at all. enable_chunked_prefill is on by default, and when it is off the WAITING path simply breaks instead of slicing once a prompt exceeds the budget (the line elided by ... above), refusing to start any prefill it cannot finish in a single step — back to the all-or-nothing behavior the chunking fix was meant to replace. The running-request pass above it uses the identical pattern:

# vllm/v1/core/sched/scheduler.py
num_new_tokens = (
    request.num_tokens_with_spec
    + request.num_output_placeholders
    - request.num_computed_tokens
)
if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens:
    num_new_tokens = self.scheduler_config.long_prefill_token_threshold
num_new_tokens = min(num_new_tokens, token_budget)

Source: vllm/v1/core/sched/scheduler.py

A decode request asks for one token (its computed count is one behind its real length), gets it, and moves on. A prefilling request asks for thousands, gets a budget-sized chunk, and stays. Same three lines. The scheduler genuinely does not distinguish them, which is the payoff the previous chapter promised.

Notice the second clamp, the one gated on long_prefill_token_threshold. The budget cap alone is enough to chunk a prefill, but it is a blunt instrument: if the budget is large and mostly idle, a single huge prefill can still eat most of a step before any decode runs. long_prefill_token_threshold is a tighter, per-request ceiling specifically for prompts the operator considers “long.” Its default is 0, meaning disabled, and the 0 < guard is what makes that work: the clamp only fires when someone has set a positive threshold. When concurrent partial prefills are turned on, the config quietly derives one for you:

# vllm/config/scheduler.py
if self.max_num_partial_prefills > 1:
    if self.long_prefill_token_threshold == 0:
        self.long_prefill_token_threshold = int(max_model_len * 0.04)

Source: vllm/config/scheduler.py

Four percent of the model’s context length. The number is a heuristic, not a law. The point is that you now have two independent knobs over prefill granularity. The global max_num_batched_tokens budget bounds the whole step, and long_prefill_token_threshold bounds any one prefill within it. Tightening the threshold guarantees more room for decodes to ride alongside a long prefill, at the cost of stretching that prefill’s time-to-first-token across more steps. That is exactly the tradeoff named in this part’s recurring theme: you are buying ITL stability for the batch by spending TTFT on the long request. There is no free chunk size; there is only where you want to sit on that curve. The curve below makes that frontier concrete for a 30,000-token prefill. Each point is one chunk size, and moving toward smaller chunks slides you left (smoother batch ITL) and up (longer TTFT for the long prompt). No setting escapes the frontier; you only pick a seat on it.

Illustrative: the shape of the TTFT-vs-ITL frontier and its direction with chunk size are real; the millisecond values are representative, not measured. Point labels are the per-step prefill chunk size in tokens.
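The two knobs compose as two clamps in sequence. The sketch below mirrors the quoted config rule and the threshold guard from the scheduler excerpt. The helper names are hypothetical. steps_to_prefill assumes the prompt gets its full clamped chunk every step and ignores competition from other requests.

```python
def effective_threshold(max_model_len, max_num_partial_prefills,
                        long_prefill_token_threshold=0):
    """Mirrors the quoted config rule: derive 4% of context when unset."""
    if max_num_partial_prefills > 1 and long_prefill_token_threshold == 0:
        long_prefill_token_threshold = int(max_model_len * 0.04)
    return long_prefill_token_threshold


def chunk_size(gap, token_budget, long_prefill_token_threshold):
    """The two clamps in order: per-request threshold, then the step budget."""
    n = gap
    if 0 < long_prefill_token_threshold < n:  # 0 means disabled
        n = long_prefill_token_threshold
    return min(n, token_budget)


def steps_to_prefill(prompt_len, token_budget, threshold):
    """Steps a prompt needs if it gets a full chunk every step (toy)."""
    done, steps = 0, 0
    while done < prompt_len:
        done += chunk_size(prompt_len - done, token_budget, threshold)
        steps += 1
    return steps


t = effective_threshold(max_model_len=32_768, max_num_partial_prefills=2)
print(t)                                  # 1310
print(steps_to_prefill(30_000, 8192, 0))  # 4: only the budget clamps
print(steps_to_prefill(30_000, 8192, t))  # 23: the threshold clamps harder
```

With a 32,768-token context the derived threshold is 1,310 tokens. That stretches the same 30,000-token prefill from 4 steps to 23. The long prompt spends more steps of TTFT, and in exchange it never takes more than 1,310 tokens of any step.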

This is the mechanism that Sarathi-Serve formalized and that we flagged back in Chapter 2 when we first met the throughput-latency tension. The paper’s framing is stall-free batching: by admitting prefill work only in budgeted slices that fit alongside ongoing decodes, you stop prefills from stalling decodes, and the decode stream stays smooth under prefill-heavy load. vLLM’s scheduler is, in this respect, an implementation of that idea, with the budget cap doing the slicing and the threshold giving you a sharper knife when you want one.

When the cache, not the budget, runs out

Chunking handles compute interference. It does nothing for the other scarce resource. Every token you prefill or decode needs a KV slot, and slots come from a fixed pool of blocks. The token budget bounds how much you compute per step; it says nothing about whether the cumulative KV footprint of all running requests still fits. Under sustained load — many long sequences all decoding, all growing their caches one block at a time — the pool eventually cannot satisfy the next allocation. The engine has over-committed, and now it must shed.

The shedding happens inside the allocation loop. When the scheduler tries to give a running request its next chunk, it asks the KV cache manager for blocks (allocate_slots), and the manager returns None if there is not enough free memory to hold them. That None is the trigger that turns a memory shortage into an eviction:

# vllm/v1/core/sched/scheduler.py
while True:
    new_blocks = self.kv_cache_manager.allocate_slots(
        request,
        num_new_tokens,
        num_lookahead_tokens=self.num_lookahead_tokens,
    )
    if new_blocks is not None:
        # The request can be scheduled.
        break
    # The request cannot be scheduled.
    # Preempt the lowest-priority request.
    if self.policy == SchedulingPolicy.PRIORITY:
        preempted_req = max(
            self.running,
            key=lambda r: (r.priority, r.arrival_time),
        )
        self.running.remove(preempted_req)
        ...
    else:
        preempted_req = self.running.pop()
    self._preempt_request(preempted_req, scheduled_timestamp)
    preempted_reqs.append(preempted_req)
    if preempted_req == request:
        # No more request to preempt. Cannot schedule this request.
        break

Source: vllm/v1/core/sched/scheduler.py

Read this as a backpressure valve. The engine cannot make memory; the only thing it can do is reclaim it from a request that currently holds some. So when an allocation fails, it picks a victim, frees that victim’s blocks, and tries the allocation again. It keeps looping until either the allocation finally succeeds or there is no one left to evict. Under the default FCFS policy it evicts from the tail of the running list. self.running.pop() takes the most recently admitted, least senior request. That is a sensible victim choice: it protects the requests that have been running longest and so have the most computed state to lose. Under priority scheduling it instead picks the genuinely lowest-priority request (highest priority value, ties broken toward the latest arrival). There is also a self-protecting terminal case. If the request the scheduler is trying to schedule is itself the lowest-priority thing in the batch, the victim picker can land on it. At that point the request has just preempted itself, preempted_req == request, and the loop gives up. The engine would rather make no progress on this request than evict something more important in order to serve it.
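The loop can be sketched with plain block counts. Req, allocate_or_preempt, and the integer free_blocks pool are hypothetical simplifications: freeing a victim just adds its block count back. The sketch also assumes the request being scheduled is itself in the running list, as it is in Pass 1.

```python
from dataclasses import dataclass


@dataclass(eq=False)  # identity comparison, like distinct request objects
class Req:
    request_id: str
    priority: int
    arrival_time: float
    blocks: int  # KV blocks currently held


def allocate_or_preempt(request, blocks_needed, running, free_blocks, policy="fcfs"):
    """Toy allocate-or-preempt loop. Assumes request is in running.

    Returns (scheduled, preempted_ids, free_blocks_after).
    """
    preempted = []
    while True:
        if blocks_needed <= free_blocks:  # stand-in for allocate_slots succeeding
            return True, preempted, free_blocks - blocks_needed
        if policy == "priority":
            victim = max(running, key=lambda r: (r.priority, r.arrival_time))
            running.remove(victim)
        else:
            victim = running.pop()  # FCFS: most recently admitted
        free_blocks += victim.blocks  # stand-in for freeing the victim's KV blocks
        victim.blocks = 0
        preempted.append(victim.request_id)
        if victim is request:  # nothing cheaper left to evict
            return False, preempted, free_blocks


a, b, c = Req("a", 0, 1.0, 4), Req("b", 0, 2.0, 4), Req("c", 0, 3.0, 4)
running = [a, b, c]
print(allocate_or_preempt(a, 6, running, free_blocks=2))  # (True, ['c'], 0)
```

In the tests below, the last case is the terminal exit. The request being scheduled is the FCFS tail, so it becomes its own victim and the loop stops without touching the more senior request.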

The diagram below traces that loop: each failed allocation sheds one request and retries, and the loop has exactly two exits — the allocation succeeding, or the scheduler running out of anything cheaper to evict than the request it is working on.

flowchart TD
    A["need blocks for next chunk"] --> B["allocate_slots(request)"]
    B --> C{"got blocks?"}
    C -->|yes| D["schedule request: proceed"]
    C -->|no| E["pick victim (FCFS tail or lowest priority)"]
    E --> F["_preempt_request(victim): free its blocks"]
    F --> G{"victim is the request itself?"}
    G -->|no| B
    G -->|yes| H["give up: cannot schedule this request now"]

What does preemption actually do? This is the part that distinguishes LLM serving from a generic load-shedding story, so it is worth reading line by line:

# vllm/v1/core/sched/scheduler.py
def _preempt_request(self, request: Request, timestamp: float) -> None:
    assert request.status == RequestStatus.RUNNING, (
        "Only running requests can be preempted"
    )
    self.kv_cache_manager.free(request)
    self.encoder_cache_manager.free(request)
    self._inflight_prefills.discard(request)
    request.status = RequestStatus.PREEMPTED
    request.num_computed_tokens = 0
    if request.spec_token_ids:
        request.spec_token_ids = []
    request.num_preemptions += 1
    if self.log_stats:
        request.record_event(EngineCoreEventType.PREEMPTED, timestamp)
    # Put the request back to the waiting queue.
    self.waiting.prepend_request(request)

Source: vllm/v1/core/sched/scheduler.py

Three things matter here. First, kv_cache_manager.free(request) returns every block the request held to the pool. That is the entire point: it is where the reclaimed memory comes from, the actual payload of the whole eviction. Second, request.num_computed_tokens = 0. The request does not pause; it is rewound to the very beginning. Every token it had already prefilled and every token it had already decoded is now marked uncomputed, so all of it must be recomputed from scratch when the request gets readmitted. This is preemption with recompute, and it is a deliberate design choice. Rather than swapping the KV cache out to host memory and copying it back later (which costs PCIe bandwidth and a good deal of complexity), vLLM throws the cache away and pays to rebuild it. The bet is that recompute on a GPU is often cheaper than a round trip over PCIe, and that preemption is rare enough that the recompute cost stays in the noise. Third, under FCFS, self.waiting.prepend_request(request) puts the victim at the front of the waiting queue, not the back (under priority it is simply reinserted by its key). A preempted request jumps the line ahead of never-started requests because it has already waited once and has already been partly served. It should be first in line when memory frees up again. That will not happen in this step, though. The blocks it just released go to the request that triggered the preemption, and Pass 2 is skipped in any step that preempted.

The diagram below follows a single request through that cycle. Note that preemption is not a terminal state: the request goes RUNNING to PREEMPTED to the front of WAITING and then back to RUNNING, losing its computed progress but keeping its place near the head of the line. A request can make this loop more than once under sustained pressure, which is exactly what num_preemptions counts.

stateDiagram-v2
    [*] --> WAITING: request arrives
    WAITING --> RUNNING: admitted, blocks allocated
    RUNNING --> RUNNING: prefill chunk or decode token
    RUNNING --> PREEMPTED: chosen as victim, free blocks, reset num_computed_tokens to 0
    PREEMPTED --> WAITING: prepended to front of queue
    RUNNING --> [*]: finished (EOS or max tokens)

The recompute cost is not entirely wasted, and this is where Chapter 7’s prefix caching quietly earns its keep ahead of schedule. When a preempted request is readmitted, its prompt blocks may still be sitting in the prefix cache (preemption frees the request’s references, but blocks linger until evicted). So the “recompute from zero” can land a long prefix-cache hit and skip most of the work it nominally threw away. The counter num_preemptions is bumped on the way out. That serves the metric we will get to, and it lets downstream prefix-cache accounting tag this request as having been preempted; you saw preempted=request.num_preemptions > 0 threaded through the stats calls in the previous chapter’s code.
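How much a surviving prefix saves is simple arithmetic. This hypothetical helper counts only whole surviving blocks as hits. It caps the hit at num_tokens - 1 so that at least one token is still computed, and it uses an illustrative block size of 16. All three are assumptions of the sketch.

```python
def tokens_to_recompute(num_tokens, cached_full_blocks, block_size=16):
    """Tokens a readmitted request must actually recompute.

    Assumptions of this sketch: only whole surviving blocks count as hits, and
    the hit is capped at num_tokens - 1 so at least one token still runs.
    block_size=16 is illustrative.
    """
    cached_tokens = min(cached_full_blocks * block_size, num_tokens - 1)
    return num_tokens - cached_tokens


# A request preempted at 1,500 tokens whose first 80 blocks survived.
print(tokens_to_recompute(1500, 80))  # 220
```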

Admission control: not over-committing in the first place

Preemption is the valve of last resort. The cheaper move is to not over-admit, and the KV cache manager has two mechanisms for that. The first is the watermark — a slice of the pool the manager refuses to hand out to newly admitted requests:

# vllm/v1/core/kv_cache_manager.py
watermark_blocks = 0
# The watermark is applied to waiting/preempted requests only, and only
# when there's at least one request already scheduled.
if has_scheduled_reqs and request.status in (
    RequestStatus.WAITING,
    RequestStatus.PREEMPTED,
):
    watermark_blocks = self.watermark_blocks

Source: vllm/v1/core/kv_cache_manager.py

and the gate itself:

# vllm/v1/core/kv_cache_manager.py
available_blocks = self.block_pool.get_num_free_blocks() - reserved_blocks
required_blocks = num_blocks_to_allocate + watermark_blocks
if required_blocks > available_blocks:
    # Cannot allocate new blocks
    return None

Source: vllm/v1/core/kv_cache_manager.py

The asymmetry is the clever part. The watermark blocks are added to required_blocks — raising the bar an allocation must clear — only for requests in WAITING or PREEMPTED state, that is, requests being admitted into the running set, and only when something is already running. An already-running request decoding its next token faces a watermark_blocks of 0 and is never blocked by the reserve; it can dip into it freely. So the reserve is invisible to requests that already have a foot in the door and a wall to requests trying to get in. The effect is hysteresis: a deliberate gap between the memory level at which you stop admitting new work and the level at which you actually run out. You stop letting new sequences in while a band of memory is still free, which leaves headroom for the sequences you already committed to keep growing for a few more steps before they collide. Without it, the engine would admit greedily right up to the last free block and then thrash, preempting on nearly every step as running requests fight over the last slivers of memory. The watermark trades a little admitted concurrency for far fewer preemptions. Its default is 0.0, disabled, because the right reserve depends on your sequence-length distribution — but on memory-pressured deployments it is one of the first knobs to reach for when you see preemptions climbing.
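The asymmetry can be sketched as a predicate over the same quantities the quoted gate uses. passes_gate is hypothetical, and statuses are plain strings here rather than vLLM's RequestStatus enum.

```python
def passes_gate(num_free_blocks, blocks_needed, status, has_scheduled_reqs,
                watermark_blocks, reserved_blocks=0):
    """Toy version of the admission gate quoted above."""
    watermark = 0
    # The reserve applies only to requests trying to get in, and only when
    # something is already scheduled.
    if has_scheduled_reqs and status in ("WAITING", "PREEMPTED"):
        watermark = watermark_blocks
    available = num_free_blocks - reserved_blocks
    return blocks_needed + watermark <= available


# 100 free blocks, a 10-block reserve, and a 95-block request.
print(passes_gate(100, 95, "RUNNING", True, 10))   # True: already in, dips into reserve
print(passes_gate(100, 95, "WAITING", True, 10))   # False: the reserve is a wall
print(passes_gate(100, 95, "WAITING", False, 10))  # True: nothing running yet
```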

The second mechanism is the full-sequence admission check, and it exists to plug a hole that chunking itself opened. Recall that admitting a request only allocates blocks for its first chunk. That is great for getting work started cheaply, but it means the admission decision is made against a tiny slice of the request’s eventual footprint: a 30,000-token prompt looks, at admission time, like a request needing only one budget-sized chunk’s worth of blocks. So the allocation succeeds even when the request’s full length could never fit in the pool. It sails in and prefills happily for a while as its cache grows chunk by chunk. Then, once that cache outgrows the room that was actually available, it preempts something (or gets preempted), a thrash the engine could have foreseen at admission. The full_sequence_must_fit path closes that trap by sizing the admission check against the whole sequence up front rather than the first chunk:

# vllm/v1/core/kv_cache_manager.py
if full_sequence_must_fit:
    # First check and fail if the full request sequence won't fit.
    full_num_tokens = min(request.num_tokens, self.max_model_len)
    num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate(
        request_id=request.request_id,
        num_tokens=full_num_tokens,
        ...
        apply_admission_cap=True,
    )
    required_blocks = num_blocks_to_allocate + watermark_blocks
    if required_blocks > self.block_pool.get_num_free_blocks():
        return None

Source: vllm/v1/core/kv_cache_manager.py

The scheduler enables this by default via scheduler_reserve_full_isl, whose own docstring names the failure mode it prevents: “checking the first chunk… Prevents over-admission and KV cache thrashing with chunked prefill.” So the two halves of this chapter are coupled. Chunked prefill makes admission cheap to fake, and the full-sequence check makes sure the engine does not fool itself with that cheapness.
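The two checks differ only in which token count gets converted into blocks. The sketch below is hypothetical: admit, blocks_for, the 2,048-token chunk, and the 500-block pool are illustration values, not vLLM defaults. It deliberately ignores prefix-cache hits, the watermark, and whatever apply_admission_cap adjusts. It mirrors only the min(request.num_tokens, self.max_model_len) sizing quoted above.

```python
import math


def blocks_for(num_tokens: int, block_size: int) -> int:
    """Fixed-size KV blocks needed to hold num_tokens tokens."""
    return math.ceil(num_tokens / block_size)


def admit(prompt_len: int, chunk_tokens: int, free_blocks: int,
          block_size: int = 16, max_model_len: int = 32_768,
          full_sequence_must_fit: bool = True) -> bool:
    """Hypothetical admission check sized against the first chunk or the whole sequence."""
    if full_sequence_must_fit:
        tokens = min(prompt_len, max_model_len)   # whole (length-capped) sequence
    else:
        tokens = min(prompt_len, chunk_tokens)    # only the first chunk
    return blocks_for(tokens, block_size) <= free_blocks


# A 30,000-token prompt against 500 free 16-token blocks (8,000 tokens of room):
print(admit(30_000, 2_048, 500, full_sequence_must_fit=False))  # True: 128 blocks, thrashes later
print(admit(30_000, 2_048, 500, full_sequence_must_fit=True))   # False: needs 1,875 blocks
```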

Watermark, full-sequence check, and preemption form a graded response to memory pressure. They are three escalating defenses, ordered from cheapest and most preventive to most expensive and most reactive. The diagram below shows where each one sits along the path of a single request asking for blocks. The first two run at admission time and cost nothing but some refused concurrency. Preemption runs only when both have been bypassed (an already-running request, exempt from the watermark, grew until the pool was genuinely exhausted) and costs a full recompute. The design goal is to keep control near the top of this ladder and out of the expensive rung at the bottom.

flowchart TD
    A["request wants blocks"] --> B{"new admission and someone running?"}
    B -->|yes| C{"fits above the watermark reserve?"}
    B -->|no| F{"free blocks for this chunk?"}
    C -->|no| R1["reject: keep waiting (cheap, preventive)"]
    C -->|yes| D{"full sequence fits? (reserve_full_isl)"}
    D -->|no| R2["reject: would thrash later (cheap, preventive)"]
    D -->|yes| F
    F -->|yes| OK["allocate and run"]
    F -->|no| P["preempt a victim: free blocks, full recompute later (expensive)"]

Watching the valve

Because preemption is the symptom of memory pressure, it is the single most diagnostic counter the engine emits. Each _preempt_request records a PREEMPTED event; the stats layer turns each event into an increment:

# vllm/v1/metrics/stats.py
elif event.type == EngineCoreEventType.PREEMPTED:
    self.num_preempted_reqs += 1

Source: vllm/v1/metrics/stats.py

which the Prometheus logger publishes as a cumulative counter:

# vllm/v1/metrics/loggers.py
counter_num_preempted_reqs = self._counter_cls(
    name="vllm:num_preemptions",
    documentation="Cumulative number of preemption from the engine.",
    labelnames=labelnames,
)

Source: vllm/v1/metrics/loggers.py

A healthy replica’s vllm:num_preemptions is flat or nearly so. A rising slope means the engine is shedding load to survive, every preemption is paying a full recompute, and effective goodput is bleeding into wasted work. The operator’s responses, in roughly increasing cost: raise the watermark to admit less aggressively, lower max_num_seqs to cap concurrency, lower max_num_batched_tokens (which also shrinks chunks and eases the cache), or add capacity. We will return to this counter in Chapter 21 as one of the first signals to read when a replica’s latency regresses, because it localizes the problem precisely: a preemption spike is unambiguously a memory-admission problem, not a kernel or a routing one.
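Reading the slope is simple arithmetic over scraped samples of the cumulative counter. The sketch below computes roughly what PromQL’s rate() computes over a window; rate() also extrapolates to the window edges. The function counter_rate and the sample values are hypothetical. A drop in the value is treated as a counter reset, such as a process restart.

```python
def counter_rate(samples: list[tuple[float, float]]) -> float:
    """Average per-second increase of a cumulative counter.

    samples: (unix_time_seconds, counter_value) pairs in time order.
    A decrease is treated as a reset, so the post-reset value is the increase.
    """
    if len(samples) < 2:
        return 0.0
    increase = 0.0
    for (_, prev), (_, cur) in zip(samples, samples[1:]):
        increase += (cur - prev) if cur >= prev else cur
    elapsed = samples[-1][0] - samples[0][0]
    return increase / elapsed if elapsed > 0 else 0.0


flat = [(0, 120), (60, 120), (120, 121)]     # healthy: one preemption in two minutes
spiking = [(0, 120), (60, 180), (120, 300)]  # 180 preemptions in two minutes
print(counter_rate(flat))      # 1/120 per second
print(counter_rate(spiking))   # 1.5 per second
```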

What is unsolved

Recompute-on-preempt is a pragmatic default, not a settled answer. Throwing away a long sequence’s KV cache and rebuilding it is genuinely expensive. When the victim is rescheduled, the engine must re-prefill its prompt plus every token it has generated so far, and it can skip only the blocks that still survive in the prefix cache. The generated tokens are already known, so they are recomputed in a prefill pass rather than resampled one decode step at a time. Even so, a long, cold sequence pays for its whole length again. Swapping the cache to host memory instead of discarding it trades that recompute for PCIe bandwidth. Which is cheaper depends on sequence shape, link speed, and how long the request stays evicted. vLLM picks recompute as the simpler primitive, but the choice is workload-dependent and not obviously right for every deployment.

Victim selection is similarly unsophisticated. FCFS evicts the youngest running request and priority evicts the lowest-priority one, but neither asks the question that would actually minimize wasted work: which request, if preempted, costs the least to rebuild? A request whose prompt is fully prefix-cached is nearly free to resume; one deep into an uncached generation is enormously expensive. A preemption policy that weighed recompute cost against seniority could shed load far more cheaply, and nothing in the current scheduler does. And the deepest tension remains the one this chapter only managed: the chunk-size choice that smooths ITL also lengthens TTFT for long prompts, and there is no setting that escapes the curve — only settings that move you along it. Sarathi-Serve quantifies that frontier; it does not abolish it.

The natural next question is whether all this recompute is necessary at all. Much of what a preempted request rebuilds, and much of what every fresh request prefills, is shared — system prompts, few-shot exemplars, conversation history that recurs across requests. If the idle KV pool could double as a content-addressed cache, the engine could pay for a shared prefix exactly once and hand it to every request that needs it, turning the most expensive part of prefill into a lookup. That is the subject of the next chapter, and it is the single largest TTFT lever in real traffic.

Further reading

  • Taming Throughput-Latency Tradeoff in LLM Inference with Sarathi-Serve — arXiv:2403.02310 — Introduces chunked-prefill and stall-free batching; the formal account of the budget-slicing mechanism this chapter grounds in vLLM’s scheduler, and the source for the TTFT-vs-ITL frontier you are tuning against.

Prefix and prompt caching: paying for shared context once

By now we have an engine that pages KV memory (Chapter 4), re-plans a token-budgeted batch every step (Chapter 5), and slices long prefills into chunks while shedding load through preemption-with-recompute (Chapter 6). Each of those mechanisms made the cost of a prefill more tolerable. None of them questioned whether the prefill needed to happen at all.

It usually doesn’t. Look at the actual prompts hitting a production endpoint and you find the same bytes over and over: a 2,000-token system prompt prepended to every request, a fixed block of few-shot examples shared across a batch job, a chat history that grows by one turn but repeats the previous twelve. The compute-bound prefill that Chapter 3 told us is the expensive phase is, in real traffic, mostly recomputation of context the engine already saw seconds ago.

The insight that turns this waste into the single biggest TTFT lever you have is almost embarrassingly simple. The KV cache for a span of tokens is a pure function of those tokens (and a few side conditions we will get to). To see why, recall from Chapter 4 what a KV block actually contains: for each token in the block, the key and value vectors that token produces at every attention layer. Those vectors depend only on the token sequence up to and including that position, the model weights, and the position itself. Run the same prefix through the same model twice and you get the same numbers, up to the small floating-point differences that batched kernels can introduce. So if two requests share a prefix, the KV blocks for that prefix are interchangeable. There is no per-request state hiding in them.

And the engine already has somewhere to keep those blocks: the idle KV pool. This is the crucial observation. When a request finishes, Chapter 4’s allocator returns its blocks to a free pool to be handed out later, but it does not zero them. A block that has been freed but not yet overwritten still physically holds valid KV state from the request that last used it. Normally that stale data is garbage we are about to clobber. But if we can find that block again by its contents before something else overwrites it, we can hand it straight to a new request and skip recomputing its prefill entirely. The free pool stops being just a free pool and becomes a content-addressed cache that happens to share storage with live allocations. The same physical block is, at any instant, either live KV for some running request, reusable cache for a future one, or scratch space about to be recycled, and the rest of this chapter is about managing those three roles on a single pool of memory.

That dual identity is the heart of the design, so it is worth seeing whole before we descend into the parts. The diagram below traces the two paths a freshly arrived request can take.

flowchart TD
    A["new request arrives"] --> B["hash its tokens block-by-block"]
    B --> C{"prefix blocks<br/>already cached?"}
    C -->|"hit"| D["reuse cached KV blocks<br/>skip their prefill"]
    C -->|"miss"| E["prefill: compute KV from scratch"]
    D --> F["compute only the uncached suffix"]
    E --> G["store new blocks in cache"]
    F --> G
    G --> H["decode normally"]
    H --> I["request finishes"]
    I --> J["blocks freed but kept as cache<br/>(not zeroed)"]
    J -.->|"later request<br/>with same prefix"| C

Naming a block by its contents

To look a block up by content, you need a name derived from content: a short, fixed-size fingerprint that two requests will compute identically if and only if their tokens match. That fingerprint is a hash. But a naive hash of one block’s 16 tokens is not enough, and seeing exactly why not is the key to this whole section.

Imagine you only hashed each block’s own tokens. Then the block holding tokens 16 through 31 would get the same name no matter what came before it in positions 0 through 15. But its KV vectors are not the same regardless of what came before, because attention at those positions looked back at the earlier tokens. Reusing that block when the earlier context differed would hand a request KV computed for someone else’s prefix. The name has to encode the entire history leading up to the block, not just the block.

vLLM solves this by chaining. It computes a name when a block fills, in get_request_block_hasher, walking the request’s tokens block-by-block and folding each block’s name into the next:

# vllm/v1/core/kv_cache_utils.py
block_hash = hash_block_tokens(
    caching_hash_fn, prev_block_hash_value, block_tokens, extra_keys
)
new_block_hashes.append(block_hash)
start_token_idx += block_size
prev_block_hash_value = block_hash

Source: vllm/v1/core/kv_cache_utils.py

The chaining is the whole trick. A block’s hash folds in its parent’s hash (the name of the block immediately before it), so the name of block N encodes the entire token sequence from position 0 through the end of block N, not just block N’s own 16 tokens. Each name is built from the name before it, like a tiny blockchain over the prompt. The hash itself:

# vllm/v1/core/kv_cache_utils.py
def hash_block_tokens(
    hash_function, parent_block_hash, curr_block_token_ids, extra_keys=None,
) -> BlockHash:
    if not parent_block_hash:
        parent_block_hash = NONE_HASH
    curr_block_token_ids_tuple = tuple(curr_block_token_ids)
    return BlockHash(
        hash_function((parent_block_hash, curr_block_token_ids_tuple, extra_keys))
    )

Source: vllm/v1/core/kv_cache_utils.py

Two consequences fall out of the parent chaining, and both are visible in the diagram below. It shows two requests that share a system prompt (blocks 0 and 1) and then diverge: request A continues with its own question, request B with a different one. Because each block’s name is computed from its parent’s name plus its own tokens, the two requests produce identical names for blocks 0 and 1, then different names from block 2 onward where their tokens differ.

flowchart LR
    seed["NONE_HASH<br/>(seed)"] --> h0["hash 0<br/>tokens 0..15<br/>(shared system prompt)"]
    h0 --> h1["hash 1<br/>tokens 16..31<br/>(shared system prompt)"]
    h1 --> h2a["hash 2a<br/>request A question"]
    h1 --> h2b["hash 2b<br/>request B question"]
    h2a --> h3a["hash 3a<br/>request A cont."]
    h2b --> h3b["hash 3b<br/>request B cont."]

First, because block N’s name folds in the names of every block before it, a match on block N means blocks 0..N-1 matched too. If request B’s hash 2b ever shows up in the cache, it can only have been computed by a request that also produced hash 1 and hash 0, that is, by a request with exactly the same first 32 tokens. That is what makes prefix lookup a single forward walk rather than a search: the walk starts at block 0, and each hit extends a prefix that is already proven to match. Second, prefixes must be identical from the very first token to match. The seed NONE_HASH feeds the very first block, so a difference anywhere in block 0 changes hash 0, which changes every name downstream. Two requests that differ only in their first token share nothing. That is exactly why you want shared context to be a literal prefix (system prompt first, then the variable part) and not, say, interleaved.
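You can check both consequences in a few lines. This is a minimal sketch, not vLLM’s code. It assumes SHA-256 over a pickled tuple as the hash function (vLLM’s caching_hash_fn is configurable) and uses a fixed all-zero seed in place of NONE_HASH so the run is repeatable. It also drops extra_keys. As in vLLM, only full blocks get a name.

```python
import hashlib
import pickle

BLOCK_SIZE = 16
SEED = bytes(32)  # stand-in for NONE_HASH; fixed so results are repeatable


def hash_block(parent: bytes, tokens: tuple[int, ...]) -> bytes:
    """Name a block by its parent's name plus its own tokens (the chaining trick)."""
    return hashlib.sha256(pickle.dumps((parent, tokens))).digest()


def block_hashes(token_ids: list[int]) -> list[bytes]:
    """Chained names for every full block; a partial tail block gets none."""
    names, parent = [], SEED
    for start in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE):
        parent = hash_block(parent, tuple(token_ids[start:start + BLOCK_SIZE]))
        names.append(parent)
    return names


def naive_hash(tokens: list[int]) -> bytes:
    """What the text warns against: a name built from one block's tokens only."""
    return hashlib.sha256(pickle.dumps(tuple(tokens))).digest()


system = list(range(32))                  # two blocks of shared system prompt
req_a = system + [100] * 16               # block 2: request A's question
req_b = system + [200] * 16               # block 2: request B's question
req_c = [999] + system[1:] + [100] * 16   # same as A except the very first token

ha, hb, hc = block_hashes(req_a), block_hashes(req_b), block_hashes(req_c)
print(ha[:2] == hb[:2], ha[2] == hb[2])        # True False: shared prefix, then divergence
print(any(x == y for x, y in zip(ha, hc)))     # False: one changed token renames every block
print(naive_hash(req_a[32:]) == naive_hash(req_c[32:]))  # True: naive names collide
```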

The first block has no parent, so the chain needs a seed. That is NONE_HASH, and how it’s chosen matters more than it looks:

# vllm/v1/core/kv_cache_utils.py
if hash_seed is None:
    NONE_HASH = BlockHash(os.urandom(32))
else:
    NONE_HASH = BlockHash(hash_fn(hash_seed))

Source: vllm/v1/core/kv_cache_utils.py

Within one process the seed is constant, so hashes are stable; across processes, an unset PYTHONHASHSEED gives every replica a different random seed, so block names are deliberately not portable between machines. That is fine for single-replica caching and becomes a real design question the moment you want a router to reason about which replica holds which prefix, which is Chapter 18’s problem.

Then there is extra_keys. A raw content hash would happily let one request reuse another’s KV blocks even when the KV is not actually interchangeable. generate_block_hash_extra_keys is the guard:

# vllm/v1/core/kv_cache_utils.py
extra_keys: list[Any] = (
    lora_extra_keys + mm_extra_keys + cache_salt_keys + prompt_embeds_keys
)

Source: vllm/v1/core/kv_cache_utils.py

Three of these are worth holding onto. The LoRA name forks the cache per adapter, because the same tokens produce different KV under different fine-tuned weights (Chapter 19 leans on exactly this). The multimodal hash keys a placeholder block to the actual image or audio that filled it, so two requests with the same <image> token but different pixels don’t collide (Chapter 8 builds its encoder cache on this same hash). And cache_salt is a deliberate escape hatch: a caller who does not want their prefix shared with other tenants passes a salt that only affects the first block’s keys, poisoning the whole chain from the root and giving a clean isolation boundary without a separate cache. The honest tradeoff here is that prefix caching is, by default, cross-request and cross-user; in a multi-tenant deployment the salt is how you opt out of side-channel timing leakage, and it is opt-in, so the secure-by-default story is on you.
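The same kind of chained hash shows how far the salt reaches. This sketch is hypothetical. It uses SHA-256 over pickled tuples, a made-up seed, and a salt string placed only in block 0’s extra keys, as described above. Two tenants hashing an identical prompt share no block name anywhere in the chain, while a tenant that repeats its own salt still hits its own entries.

```python
import hashlib
import pickle
from typing import Optional


def chained_hashes(token_ids: list[int], block_size: int = 16,
                   cache_salt: Optional[str] = None) -> list[bytes]:
    """Chained block names; the salt, if any, enters only block 0's extra keys."""
    names, parent = [], b"seed"  # hypothetical seed
    for i, start in enumerate(range(0, len(token_ids) - block_size + 1, block_size)):
        extra = (cache_salt,) if (i == 0 and cache_salt is not None) else None
        block = tuple(token_ids[start:start + block_size])
        parent = hashlib.sha256(pickle.dumps((parent, block, extra))).digest()
        names.append(parent)
    return names


prompt = list(range(64))                        # four full blocks
shared = chained_hashes(prompt)                 # no salt: shareable across tenants
tenant_a = chained_hashes(prompt, cache_salt="tenant-a")
tenant_b = chained_hashes(prompt, cache_salt="tenant-b")
print(set(shared) & set(tenant_a))              # set(): the salt poisons the chain from the root
print(set(tenant_a) & set(tenant_b))            # set(): tenants are isolated
print(chained_hashes(prompt, cache_salt="tenant-a") == tenant_a)  # True: same tenant still hits
```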

The longest-prefix walk

With names in hand, a new request’s cache lookup is get_computed_blocks:

# vllm/v1/core/kv_cache_manager.py
# NOTE: When all tokens hit the cache, we must recompute the last token
# to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1.
max_cache_hit_length = request.num_tokens - 1
computed_blocks, num_new_computed_tokens = (
    self.coordinator.find_longest_cache_hit(
        request.block_hashes, max_cache_hit_length
    )
)

Source: vllm/v1/core/kv_cache_manager.py

That - 1 is small and easy to miss, but it encodes something fundamental from Chapter 3. The first output token is sampled from logits (the model’s score over the vocabulary for the next position), and those logits come out of the forward pass at the prompt’s last position. If a request were a 100% cache hit, every token’s KV would already be sitting in cache and there would be no forward pass left to run. The engine would then have no logits for the next position and nothing to sample. So the engine caps the hit one token short of the prompt and always recomputes at least the final position, guaranteeing one real forward pass that yields the logits decode needs. Because hits are block-aligned, a fully cached prompt whose length is an exact multiple of the block size re-runs its whole last block, a deliberate inefficiency the code flags as future cleanup.
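The cap and the block alignment combine into one line of integer arithmetic. This sketch is hypothetical: cacheable_prefix_tokens is not a vLLM function. It assumes that the walk’s max_num_blocks is max_cache_hit_length // block_size and that only whole blocks are reused.

```python
def cacheable_prefix_tokens(prompt_len: int, cached_blocks: int,
                            block_size: int = 16) -> int:
    """Prompt tokens whose prefill can be skipped, given how many leading
    full blocks are cached. Hits stop one token short of the prompt, and
    only whole blocks count."""
    max_cache_hit_length = prompt_len - 1
    max_hit_blocks = max_cache_hit_length // block_size
    return min(cached_blocks, max_hit_blocks) * block_size


print(cacheable_prefix_tokens(64, cached_blocks=4))  # 48: fully cached, block-aligned; last block re-runs
print(cacheable_prefix_tokens(70, cached_blocks=4))  # 64: only the 6-token partial tail is computed
print(cacheable_prefix_tokens(16, cached_blocks=1))  # 0: a one-block prompt is recomputed whole
```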

The walk itself, for full attention, is the loop the parent-chaining was designed for:

# vllm/v1/core/single_type_kv_cache_manager.py
for block_hash in itertools.islice(block_hashes, max_num_blocks):
    # block_hashes is a chain of block hashes. If a block hash is not
    # in the cached_block_hash_to_id, the following block hashes are
    # not computed yet for sure.
    if cached_block := block_pool.get_cached_block(
        block_hash, kv_cache_group_ids
    ):
        for computed, cached in zip(computed_blocks, cached_block):
            computed.append(cached)
    else:
        break

Source: vllm/v1/core/single_type_kv_cache_manager.py

Walk the chain, stop at the first miss. Concretely: take the request’s list of block hashes (computed exactly as in the previous section) and look up the first one in cached_block_hash_to_id, the dictionary mapping a block name to the physical block currently holding that KV. If it is present, that block is a hit; record it and move to the next name. The instant a name is absent, stop. The comment states the invariant the design leans on, and it is the first consequence from the diagram above turned into a loop guard: a miss at block K ends the usable prefix. Blocks K+1 onward carry names derived from K’s, so they are normally absent too. Even if one of them happened to survive in the cache, the request could not use it without the KV for block K in front of it. There is no point looking further. The cost is linear in the number of hit blocks, one dictionary lookup each, with no tree traversal and no scanning past the first miss.
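The walk comes down to one dictionary probe per block name. In this minimal sketch, find_longest_cache_hit borrows the name of the coordinator method quoted above, but it is a hypothetical standalone function. A plain dict stands in for cached_block_hash_to_id, strings stand in for real block hashes, and a block size of 16 is assumed for the token count.

```python
from collections.abc import Hashable


def find_longest_cache_hit(block_hashes: list[Hashable],
                           cached: dict[Hashable, int],
                           max_num_blocks: int) -> list[int]:
    """Return physical block ids for the contiguous hit prefix, stopping at the first miss."""
    hits: list[int] = []
    for name in block_hashes[:max_num_blocks]:
        block_id = cached.get(name)
        if block_id is None:
            break  # a miss ends the usable prefix
        hits.append(block_id)
    return hits


cached = {"h0": 7, "h1": 3, "h3": 9}          # h3 survived, h2 was evicted
hits = find_longest_cache_hit(["h0", "h1", "h2", "h3"], cached, max_num_blocks=4)
print(hits, len(hits) * 16)                   # [7, 3] 32: h3 is unusable behind the hole at h2
```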

The tokens those hit blocks cover are returned as num_new_computed_tokens, and here is where the cache rejoins the rest of the engine. The scheduler simply advances the request’s num_computed_tokens past them, the same counter Chapter 5 uses to track how far a request has been prefilled. To the scheduler, a request with a cache hit is indistinguishable from one whose first few prefill chunks already ran in earlier steps: it is just a request that arrives partway done. Continuous batching from Chapter 5 needs no special case at all, which is why the cache could be bolted onto the existing scheduler without disturbing it.

Keeping a freed block alive long enough to be found

The subtle part is not the lookup. It is that the blocks being looked up are simultaneously eviction candidates. This is the three-roles tension from the opening section made concrete. A cached block whose owning request has finished sits in the free queue with ref_cnt == 0 (its reference count, the number of live requests using it, has dropped to zero), available to be handed out for new allocations, and still listed in the hash map as a valid cache hit. It is in two data structures at once, playing two roles, and the two roles have to be reconciled the instant either one fires.

The state machine below is the spine of this section. A block moves between three states: live (held by one or more running requests, ref_cnt > 0), cached-and-free (ref_cnt == 0, sitting in the free queue but still findable by name), and the brief act of being recycled into a fresh live block. Two events drive the interesting transitions out of cached-and-free, and the rest of this section is just those two events: a cache hit (touch) pulls the block back to live, while an allocation under memory pressure (get_new_blocks) evicts it.

stateDiagram-v2
    [*] --> Live: allocate for a request
    Live --> Live: more requests share it, ref_cnt++
    Live --> CachedFree: request finishes, ref_cnt to 0, block kept in hash map
    CachedFree --> Live: touch, cache hit rescues it, remove from free queue
    CachedFree --> Recycled: get_new_blocks pops it for a new allocation
    Recycled --> Live: evict old hash, reset block, ref_cnt++
    Recycled --> [*]

Start with the rescue. When a hit lands on a block that is sitting in the free queue, the engine has to yank it out of that queue before it gets handed to someone else. That is touch:

# vllm/v1/core/block_pool.py
def touch(self, blocks: Sequence[KVCacheBlock]) -> None:
    for block in blocks:
        # ref_cnt=0 means this block is in the free list (i.e. eviction
        # candidate), so remove it.
        if block.ref_cnt == 0 and not block.is_null:
            self.free_block_queue.remove(block)
        block.ref_cnt += 1

Source: vllm/v1/core/block_pool.py

A hit rescues the block from the free queue and bumps its refcount, pulling it back from the brink of eviction. The FreeKVCacheBlockQueue is a hand-rolled doubly linked list precisely so this mid-queue removal is $O(1)$. The docstring is candid that it replaces Python’s C-implemented deque only because it needs cheap deletion from the middle of the queue. The result is a “lazy LRU”: blocks are never proactively evicted on a timer or under a budget. They linger in the free list, fully reusable as cache, until the moment their physical slot is actually needed for a fresh allocation.
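The sketch below is minimal and shows why the queue is a doubly linked list. Block, FreeQueue, and touch are simplified stand-ins, not vLLM’s KVCacheBlock, FreeKVCacheBlockQueue, or BlockPool.touch. Each block carries prev/next pointers, so a cache hit can unlink it from the middle of the queue in constant time, while allocation pops the head.

```python
from typing import Iterable, Optional


class Block:
    """Simplified KV block: an id, a refcount, and free-list links."""

    def __init__(self, block_id: int) -> None:
        self.block_id = block_id
        self.ref_cnt = 0
        self.prev: Optional["Block"] = None
        self.next: Optional["Block"] = None


class FreeQueue:
    """Doubly linked free list: O(1) append at the tail, pop at the head,
    and removal from the middle (which collections.deque can't do in O(1))."""

    def __init__(self, blocks: Iterable[Block]) -> None:
        self.head: Optional[Block] = None
        self.tail: Optional[Block] = None
        for b in blocks:
            self.append(b)

    def append(self, b: Block) -> None:
        b.prev, b.next = self.tail, None
        if self.tail is None:
            self.head = b
        else:
            self.tail.next = b
        self.tail = b

    def remove(self, b: Block) -> None:
        # Relink the neighbours around b; no scan of the list is needed.
        if b.prev is None:
            self.head = b.next
        else:
            b.prev.next = b.next
        if b.next is None:
            self.tail = b.prev
        else:
            b.next.prev = b.prev
        b.prev = b.next = None

    def popleft(self) -> Block:
        if self.head is None:
            raise IndexError("no free blocks")
        b = self.head
        self.remove(b)
        return b

    def ids(self) -> list[int]:
        out, cur = [], self.head
        while cur is not None:
            out.append(cur.block_id)
            cur = cur.next
        return out


def touch(queue: FreeQueue, b: Block) -> None:
    """A cache hit on a free block (ref_cnt == 0) rescues it from the queue."""
    if b.ref_cnt == 0:
        queue.remove(b)
    b.ref_cnt += 1


blocks = [Block(i) for i in range(4)]
q = FreeQueue(blocks)
touch(q, blocks[2])            # hit on block 2, unlinked mid-queue
print(q.ids())                 # [0, 1, 3]
print(q.popleft().block_id)    # 0: allocation takes the coldest block at the head
```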

That moment is get_new_blocks, and the eviction is incidental:

# vllm/v1/core/block_pool.py
ret: list[KVCacheBlock] = self.free_block_queue.popleft_n(num_blocks)
if self.enable_caching:
    for block in ret:
        self._maybe_evict_cached_block(block)
        assert block.ref_cnt == 0
        block.ref_cnt += 1

Source: vllm/v1/core/block_pool.py

You allocate by popping from the head of the free queue, and only then, lazily, do you tear down whatever cache entry the recycled block was still serving. _maybe_evict_cached_block resets the block’s hash and removes it from the lookup map, which is the Recycled --> Live edge in the state machine above: the moment of eviction and re-allocation are the same moment. The queue’s ordering is what makes this a least-recently-used (LRU) policy: freed blocks go to the tail, allocation takes from the head, so blocks that have sat unused the longest (the coldest) are reused first and recently-touched hot prefixes, near the tail, survive.

There is a deliberate refinement in which order a single finished request’s blocks rejoin the queue. They are freed “in reverse order so that the tail blocks are evicted first.” Think about which block of a shared system prompt is most valuable to keep: block 0, the root, because every future request that shares any of this prefix must hit block 0 first. The last block of the request, by contrast, holds that request’s unique tail and will likely never be hit again. Freeing in reverse pushes the unique tail blocks toward the head of the queue (evicted soonest) and leaves the shared root blocks nearer the tail (evicted last), so the part of the prefix most likely to be reused is the part the LRU protects hardest.
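Any FIFO free list shows the effect of freeing in reverse. In this hypothetical sketch, a collections.deque stands in for the free queue, which is fine here because nothing is removed from the middle. The string labels are invented block names for one finished request.

```python
from collections import deque


def free_request(free_queue: deque, request_blocks: list[str]) -> None:
    """Return a finished request's blocks to the tail, last block first,
    so its unique tail sits nearer the head (evicted sooner) than its root."""
    for block in reversed(request_blocks):
        free_queue.append(block)


def allocate(free_queue: deque, n: int) -> list[str]:
    # Allocation pops from the head: the coldest blocks are recycled first.
    return [free_queue.popleft() for _ in range(n)]


q: deque = deque()
free_request(q, ["root", "system-2", "user-tail"])
print(list(q))          # ['user-tail', 'system-2', 'root']
print(allocate(q, 1))   # ['user-tail']: the unique tail goes first, the root survives
```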

The payoff is measured by two counters wired straight into the scheduler’s stats:

# vllm/v1/core/kv_cache_manager.py
self.prefix_cache_stats.record(
    num_tokens=request.num_tokens,
    num_hits=num_new_computed_tokens,
    preempted=request.num_preemptions > 0,
)

Source: vllm/v1/core/kv_cache_manager.py

These surface as vllm:prefix_cache_queries and vllm:prefix_cache_hits, both counted in tokens, not requests. That denominator is the right one: it tells you what fraction of prefill tokens you avoided computing, which maps almost directly to TTFT saved. Because prefill cost is roughly linear in the number of tokens actually computed, the uncached prefill left on the critical path is just $(1 - \text{hit rate}) \times \text{prompt length}$, and the curve below shows that line falling for a 2,000-token and a 4,000-token prompt: at 80% it is the difference between a 2,000-token prefill and a 400-token one. A hit rate of 80% on a workload with long shared system prompts is not unusual.

Illustrative: assumes prefill cost is linear in tokens computed, so uncached prefill = $(1-\text{hit rate}) \times \text{prompt length}$; absolute TTFT depends on model and hardware.
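The arithmetic behind that curve fits in two functions. This sketch is illustrative and hypothetical, and the counter values are made up. It computes the token-level hit rate from the two counters (in practice, from their increase over the same window) and applies the caption’s linear-cost assumption.

```python
def prefix_cache_hit_rate(query_tokens: float, hit_tokens: float) -> float:
    """Fraction of prompt tokens served from cache (both counters count tokens)."""
    return hit_tokens / query_tokens if query_tokens else 0.0


def uncached_prefill_tokens(prompt_len: int, hit_rate: float) -> float:
    """Prefill tokens left on the critical path under a linear-cost model."""
    return (1.0 - hit_rate) * prompt_len


rate = prefix_cache_hit_rate(query_tokens=1_000_000, hit_tokens=800_000)
print(rate)                                          # 0.8
print(round(uncached_prefill_tokens(2_000, rate)))   # 400
print(round(uncached_prefill_tokens(4_000, rate)))   # 800
```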

Watch these two series together with the preemption counter from Chapter 6: a cache that thrashes (high queries, collapsing hits, rising preemptions) means your working set of prefixes no longer fits, and no scheduler tuning will fix a capacity problem.

Emitting the cache as a stream of events

So far everything is local to one replica. But the block lifecycle is also observable, and that turns out to matter for the fleet. When a block is cached, cache_full_blocks enqueues a BlockStored; when one is evicted, _maybe_evict_cached_block enqueues a BlockRemoved:

# vllm/distributed/kv_events.py
class BlockStored(KVCacheEvent):
    block_hashes: list[ExternalBlockHash]
    parent_block_hash: ExternalBlockHash | None
    token_ids: list[int]
    block_size: int

Source: vllm/distributed/kv_events.py

Each event carries the block’s hash and its parent’s hash. That parent pointer is not decoration: it is the same parent link the in-process hash chain used, now shipped over the wire. A stream of (hash, parent_hash) pairs is the prefix tree, serialized. Replay the stream and you rebuild the tree edge by edge, because each BlockStored tells you “this block’s name, and the name of the block it hangs off of,” which is exactly an edge. An external consumer subscribing to these events can therefore reconstruct an approximate model of which prefixes live on which replica without ever touching that replica’s memory.
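Replaying the stream takes little more than a parent map per replica. This is a minimal sketch under stated assumptions. The event classes are simplified stand-ins: vLLM’s BlockStored also carries token_ids and block_size. Each stored event’s block_hashes is assumed to be a chain hanging off parent_block_hash in order, the removal event’s shape is assumed, and PrefixIndex is hypothetical. The last lines play the router from the sequence below.

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class BlockStored:
    block_hashes: list[str]              # assumed to be a chain, in order
    parent_block_hash: Optional[str]     # parent of the first hash in the list


@dataclass
class BlockRemoved:
    block_hashes: list[str]


class PrefixIndex:
    """One replica's cached prefixes, rebuilt purely from its event stream."""

    def __init__(self) -> None:
        self.parent: dict[str, Optional[str]] = {}  # child hash -> parent hash (tree edges)

    def apply(self, event: Union[BlockStored, BlockRemoved]) -> None:
        if isinstance(event, BlockStored):
            parent = event.parent_block_hash
            for h in event.block_hashes:
                self.parent[h] = parent   # record one edge of the prefix tree
                parent = h
        else:
            for h in event.block_hashes:
                self.parent.pop(h, None)

    def matched_blocks(self, request_hashes: list[str]) -> int:
        """Length of the contiguous cached prefix, walked as the engine would."""
        n = 0
        for h in request_hashes:
            if h not in self.parent:
                break
            n += 1
        return n


replicas = {"A": PrefixIndex(), "B": PrefixIndex()}
replicas["A"].apply(BlockStored(["H0", "H1"], None))
replicas["B"].apply(BlockStored(["H0"], None))

request = ["H0", "H1", "H2"]
best = max(replicas, key=lambda r: replicas[r].matched_blocks(request))
print(best)  # A: it holds the longer prefix
```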

The sequence below traces the payoff: a router watches the event streams from two replicas, learns that replica A cached a particular system prompt, and steers the next request carrying that prompt to A so its prefill turns into a cache hit.

sequenceDiagram
    participant RA as Replica A
    participant Router as Cache-aware router
    participant Client
    RA->>Router: BlockStored hash=H0 parent=none
    RA->>Router: BlockStored hash=H1 parent=H0
    Note over Router: rebuild prefix tree, A holds H0 to H1
    Client->>Router: request with same system prompt
    Note over Router: its prefix hashes to H0 and H1, A has them
    Router->>RA: route here for a cache hit
    RA-->>Client: fast TTFT, prefill skipped

This is the signal a cache-aware router consumes to send your request to the machine that already holds your system prompt, and we have deliberately left the wire format and the routing policy as a thread to pull in Chapter 16 (offloading evicted blocks to a larger store) and Chapter 18 (routing on this signal). For now the point is only that the engine was built to emit its cache state, not just hold it.

A different shape for the same idea: RadixAttention

vLLM’s structure is a flat hash map from block name to block, with the prefix tree existing only implicitly in the parent-chained hashes. SGLang took the other branch: make the tree explicit. Its RadixAttention keeps cached KV in a radix tree (a compressed trie) keyed on the token sequence, where shared prefixes are literally shared tree edges and the LRU eviction policy operates on tree leaves.

SGLang: Efficient Execution of Structured LM Programs (arXiv:2312.07104) introduces RadixAttention, treating the KV cache as a radix tree so prefix sharing and eviction are tree operations rather than hash-map operations.

The contrast is instructive precisely because both systems solve the same problem. vLLM’s flat map gives $O(1)$ per-block lookup and a dead-simple data structure, at the cost of fixed 16-token block granularity (the DEFAULT_BLOCK_SIZE in vllm/config/cache.py): a shared prefix that ends mid-block can’t be partially reused. A radix tree can split at arbitrary token boundaries and makes the “longest common prefix across many cached sequences” query native, which is appealing when prefixes branch heavily, as they do in structured agent programs where one prompt forks into many continuations. The cost is a more intricate structure and tree maintenance on every insert and evict. Neither is strictly better; vLLM’s bet is that block granularity is cheap enough and the flat map’s simplicity buys robustness. SGLang’s bet is that explicit tree structure pays for itself when sharing is the common case. Worth reading the paper for the tree-eviction details, which are more subtle than vLLM’s lazy-LRU and have to keep interior nodes alive while shedding leaves.
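For contrast, here is a toy radix tree over token ids. Its edges split at whatever token two sequences diverge on. It is a hypothetical illustration of the data structure, not SGLang’s implementation: it stores no KV, tracks no reference counts, and does no eviction.

```python
class Node:
    def __init__(self) -> None:
        # first token of an outgoing edge -> (edge token tuple, child node)
        self.children: dict = {}


def _common_len(edge: tuple, tokens: list[int], i: int) -> int:
    """How many tokens of edge match tokens starting at position i."""
    j = 0
    while j < len(edge) and i + j < len(tokens) and edge[j] == tokens[i + j]:
        j += 1
    return j


def insert(root: Node, tokens: list[int]) -> None:
    node, i = root, 0
    while i < len(tokens):
        first = tokens[i]
        if first not in node.children:
            node.children[first] = (tuple(tokens[i:]), Node())  # new leaf edge
            return
        edge, child = node.children[first]
        j = _common_len(edge, tokens, i)
        if j < len(edge):
            # Diverged mid-edge: split it so the shared part becomes its own edge.
            mid = Node()
            mid.children[edge[j]] = (edge[j:], child)
            node.children[first] = (edge[:j], mid)
            child = mid
        node, i = child, i + j


def match_prefix(root: Node, tokens: list[int]) -> int:
    """Number of leading tokens already present in the tree, at any boundary."""
    node, i = root, 0
    while i < len(tokens) and tokens[i] in node.children:
        edge, child = node.children[tokens[i]]
        j = _common_len(edge, tokens, i)
        i += j
        if j < len(edge):
            break
        node = child
    return i


root = Node()
insert(root, [1, 2, 3, 4, 5])
insert(root, [1, 2, 3, 9])                  # splits edge (1, 2, 3, 4, 5) after token 3
print(match_prefix(root, [1, 2, 3, 4, 7]))  # 4
print(match_prefix(root, [1, 2, 8]))        # 2

shared = list(range(20))                    # a 20-token shared prefix
root2 = Node()
insert(root2, shared + [500] * 10)
hit = match_prefix(root2, shared + [600] * 10)
print(hit, (hit // 16) * 16)                # 20 16: 16-token blocks could reuse only 16
```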

What this does and does not solve

Prefix caching is the rare optimization with almost no downside on the workloads it targets: when prefixes repeat, you pay for them once, TTFT collapses, and the only cost is hash computation you were going to do anyway for memory management. The honest caveats are about when it doesn’t help. Workloads with no shared structure (every prompt unique, no system prompt, single-turn) see hit rates near zero and pay a small hashing tax for nothing. Capacity is finite, so a large enough working set of distinct prefixes thrashes the cache into uselessness, which is the motivation for spilling cold blocks to CPU and disk in Chapter 16. The block-aligned granularity wastes the tail of every partial block, including that forced last-block recompute for logits. And multi-tenant sharing is a real side channel that the cache salt only mitigates if you remember to use it.

The deeper open problem is that prefix reuse is a property of the traffic, and the engine is at the mercy of whoever shapes prompts upstream. Put the variable part first and you have defeated the entire mechanism. Much of the value of prefix caching in practice is captured not in the engine at all but in the conventions of the API above it and the router in front of it, which is why this chapter ends pointing forward. Next we leave decoder-only text behind: Chapter 8 takes a multimodal request, runs an encoder before the LLM, and discovers that the encoder’s outputs want their own prefix cache, keyed by the very media hash we just saw threaded through extra_keys.

Further reading

  • SGLang: Efficient Execution of Structured LM Programs — arXiv:2312.07104 — RadixAttention as an explicit radix-tree alternative to vLLM’s flat hash-of-blocks, with tree-native prefix sharing and leaf eviction.
  • Efficient Memory Management for Large Language Model Serving with PagedAttention — arXiv:2309.06180 — the paged-block foundation (Chapter 4) that makes content-addressed block reuse possible in the first place.

Multimodal requests: the encoder cache and a second budget

Up to now every request in this book has been a sequence of token ids. The token-budget scheduler from Chapter 5 plans a batch by counting tokens; chunked prefill in Chapter 6 slices a long prompt against that single budget; prefix caching in Chapter 7 lets identical prefixes skip the prefill compute entirely, keyed by a content hash of the block’s tokens plus a few extra_keys. One budget, one cache, one currency. That tidy picture survives exactly until someone sends a picture.

A multimodal request arrives with an image (or several images, audio, video) interleaved into the prompt. The model cannot attend to raw pixels. Before the language model sees anything, a separate vision (or audio) tower has to run a forward pass that turns the media into a block of embeddings, and those embeddings are then spliced into the token stream where placeholder tokens stand in for them. This is the architecture that LLaVA: Visual Instruction Tuning (arXiv:2304.08485) made the default template: a pretrained vision encoder, a small projection into the language model’s embedding space, and an otherwise ordinary autoregressive decoder. The serving consequence is blunt. There are now two models behind one request, the encoder runs first, and its output is large, reusable, and expensive enough that you do not want to recompute it.

So the engine grows a second budget and a second cache. The decoder still spends the token budget you already understand. The encoder spends a separate compute budget, measured in encoder embeddings, and parks its outputs in a separate cache, the EncoderCacheManager. Both of those resources are arbitrated inside the same schedule() pass you read in Chapter 5, which is what makes this a scheduling problem and not just a model-loading detail. This chapter is about how those two budgets compete step by step, how the encoder cache becomes its own prefix-caching problem keyed by media hash, and how that media hash threads back into the Chapter 7 block hash so the two caches stay consistent.

The diagram below traces the path of a single multimodal request through the two models, and shows where the two budgets are spent and where the two caches sit. Pixels enter on the left, become embeddings in the encoder, get spliced into the placeholder slots in the prompt, and only then does the decoder run. Each model has a resource that gates it: the encoder is gated by the encoder compute budget plus the EncoderCacheManager, the decoder by the token budget plus the KV cache from Chapter 4.

flowchart LR
    IMG["image bytes"] --> HASH["media hash (identifier)"]
    HASH --> ENC["vision encoder (one forward pass)"]
    ENC --> EMB["block of embeddings"]
    EMB --> ECACHE["EncoderCacheManager (keyed by media hash)"]
    ECACHE --> SPLICE["splice embeddings into placeholder slots"]
    PROMPT["text tokens + placeholder tokens"] --> SPLICE
    SPLICE --> DEC["language decoder (autoregressive)"]
    DEC --> KV["KV cache (Chapter 4)"]
    EBUD["encoder compute budget"] -->|gates| ENC
    TBUD["token budget (Chapter 5)"] -->|gates| DEC

The two halves look symmetric, but their resources behave differently, and the rest of the chapter is mostly about that asymmetry: where the media hash comes from and why it has to be shared between the two caches, how the encoder’s separate budget competes with the decoder’s inside one scheduling step, and what happens to the encoder cache over the lifetime of a request.

Where the media hash comes from

Everything starts with a hash of the bytes. Before a request becomes a Request, the multimodal processor (dispatched through vllm/multimodal/registry.py) turns each media item into a MultiModalFeatureSpec, and the engine attaches a list of those specs to the request. From vllm/v1/request.py:

self.mm_features = mm_features or []

Each spec carries the modality, the processed data, the placeholder location in the prompt, and two hashes. Two hashes, because the same image can be processed under different LoRA adapters (Chapter 19), and an encoder output produced under one adapter must not be served to a request using another. The mm_hash keys the processor output (the raw, adapter-independent result of decoding the bytes); the identifier keys the encoder output and folds in a LoRA prefix when one applies, so it is the stricter of the two. The fields that matter for the rest of this chapter are in vllm/multimodal/inputs.py:

identifier: str
"""The hash for caching encoder outputs (with LoRA prefix if applicable)."""
...
mm_position: PlaceholderRange
"""The location of the `modality` tokens corresponding to this item in the prompt..."""
mm_hash: str | None = None
"""The hash for caching processor outputs (without LoRA prefix)."""

The mm_position (a PlaceholderRange) is the third piece you need to hold onto: it records where in the prompt this item’s placeholder tokens sit. A multimodal prompt is not pure media. It is a token sequence in which a contiguous run of placeholder tokens marks the slot where the image’s embeddings will be substituted, surrounded by ordinary text. The mm_position is what tells the scheduler the start and length of that slot, and it is what later lets the scheduler decide whether an image even overlaps the token range it is about to compute this step.

The hash itself is content-addressed, computed in vllm/multimodal/hasher.py. For a PIL image it serializes mode, pixels, and palette; for a tensor it serializes dtype, shape, and raw bytes; then it folds everything through blake3 (or sha256/sha512 for FIPS):

@classmethod
def hash_kwargs(cls, **kwargs: object) -> str:
    hasher_factory = _get_hasher_factory(envs.VLLM_MM_HASHER_ALGORITHM)
    hasher = hasher_factory()
    for k, v in sorted(kwargs.items(), key=lambda kv: kv[0]):
        for bytes_ in cls.iter_item_to_bytes(k, v):
            hasher.update(bytes_)
    return hasher.hexdigest()
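The same content-addressing idea fits in a few lines of standard-library Python. This is a minimal sketch, not vLLM's hasher. It assumes SHA-256 from hashlib in place of blake3, uses a toy serializer (_item_to_bytes, hypothetical) that only understands bytes, str, and int, and invents an "adapter:hash" format for the LoRA-prefixed identifier.

```python
import hashlib
from typing import Iterator, Optional


def _item_to_bytes(key: str, value: object) -> Iterator[bytes]:
    # Toy serializer: the real hasher has per-type rules (PIL mode/pixels/palette,
    # tensor dtype/shape/bytes); this sketch only understands bytes, str and int.
    yield key.encode("utf-8")
    if isinstance(value, bytes):
        yield value
    elif isinstance(value, str):
        yield value.encode("utf-8")
    elif isinstance(value, int):
        yield value.to_bytes(8, "little", signed=True)
    else:
        raise TypeError(f"cannot hash {key!r} of type {type(value).__name__}")


def hash_kwargs(**kwargs: object) -> str:
    hasher = hashlib.sha256()  # stand-in for blake3, which is not in the stdlib
    # Sort by key so the digest does not depend on keyword-argument order.
    for key, value in sorted(kwargs.items(), key=lambda kv: kv[0]):
        for chunk in _item_to_bytes(key, value):
            # Length-prefix each chunk so ("a", b"bc") and ("ab", b"c") differ.
            hasher.update(len(chunk).to_bytes(8, "little"))
            hasher.update(chunk)
    return hasher.hexdigest()


def encoder_identifier(mm_hash: str, lora_name: Optional[str]) -> str:
    # Hypothetical prefix format: fold the adapter into the encoder-output key.
    return f"{lora_name}:{mm_hash}" if lora_name else mm_hash
```

Sorting by key is what makes the digest independent of keyword order. The length prefixes are this sketch's own guard against two different argument splits concatenating to the same byte stream.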

This identifier is the join key for the whole chapter. The same image sent by two different users hashes identically, so its encoder output can be shared. And, crucially, this is the same hash that Chapter 7’s prefix cache folds into its block hash. Recall that the block hash takes extra_keys; for multimodal blocks those keys are generated in vllm/v1/core/kv_cache_utils.py:

# The block contains the current mm input. Include its offset
# relative to the start of the block so prefix-cache keys stay
# distinct when the same MM item appears at different positions
# within otherwise-identical placeholder blocks.
extra_keys.append((mm_feature.identifier, offset - start_token_idx))

That single line is the seam between the two caches. To see why it has to exist, picture two requests whose prompts are byte-for-byte identical as token ids: same text, same number of placeholder tokens, same positions. To Chapter 7’s prefix cache, which hashes block contents by token id, those two prompts look like the same prefix and would share KV blocks. But if the two requests carried different images behind those identical placeholders, sharing the blocks would be a correctness bug, because the KV computed for one image would be reused for the other. Folding the image’s identifier into the block’s extra_keys closes that hole: the decoder’s KV blocks covering placeholder tokens are a prefix-cache hit only when the image bytes also match. (The offset half of the tuple guards a subtler case: the same image appearing at a different position inside otherwise-identical placeholder blocks must still hash distinctly.)

So the media hash is doing double duty, and the diagram below traces both jobs from the single hash. Computed once from the bytes, it keys the encoder-output cache directly, and it is also injected into the decoder’s prefix-cache block hash so the KV side stays honest about which pixels a block actually saw.

flowchart TD
    BYTES["image bytes"] --> H["media hash (identifier)"]
    H --> J1["join 1: key into EncoderCacheManager"]
    H --> J2["join 2: extra_keys in KV block hash"]
    J1 --> R1["same image, two users: share one encoder output"]
    J2 --> R2["different images, same placeholder tokens: never share KV blocks"]
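The second join can be sketched the same way. The helpers below (block_hash, mm_extra_keys, hash_prompt) are hypothetical simplifications of the chained block hash from Chapter 7. They hash repr() output with SHA-256, use 4-token blocks, and attach an (identifier, offset minus block start) key for every media item whose placeholder run overlaps the block, which is a simplified reading of the extra_keys line above.

```python
import hashlib
from typing import Optional, Sequence

BLOCK_SIZE = 4  # tiny blocks for the example; vLLM's default is 16 tokens


def block_hash(parent: Optional[bytes], token_ids: Sequence[int], extra_keys: tuple) -> bytes:
    # Chain the parent hash, then mix in this block's token ids and extra keys.
    payload = repr((parent, tuple(token_ids), extra_keys)).encode("utf-8")
    return hashlib.sha256(payload).digest()


def mm_extra_keys(start: int, end: int, features: Sequence[tuple[str, int, int]]) -> tuple:
    # features: (identifier, placeholder offset, placeholder length) per media item.
    keys = []
    for identifier, offset, length in features:
        if offset < end and offset + length > start:  # item overlaps this block
            keys.append((identifier, offset - start))
    return tuple(keys)


def hash_prompt(token_ids: Sequence[int], features: Sequence[tuple[str, int, int]]) -> list[bytes]:
    hashes: list[bytes] = []
    parent: Optional[bytes] = None
    num_full = len(token_ids) // BLOCK_SIZE  # only full blocks are hashed
    for b in range(num_full):
        start, end = b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE
        parent = block_hash(parent, token_ids[start:end], mm_extra_keys(start, end, features))
        hashes.append(parent)
    return hashes
```

With identical token ids, swapping the image changes every block hash that covers a placeholder, while a text-only block in front of the image still matches and can be shared.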

A second cache with its own eviction

The EncoderCacheManager in vllm/v1/core/encoder_cache_manager.py is, structurally, a smaller and simpler cousin of the KV block pool from Chapter 4. It is sized in encoder embeddings, not blocks, and it tracks who is using what:

# mm_hash of mm_data => ids of requests that reference the mm_data
self.cached: dict[str, set[str]] = {}
...
# mm_hash of mm_data => num_encoder_embeds of the mm_data
self.freeable: OrderedDict[str, int] = OrderedDict()
self.freed: list[str] = []

The pattern mirrors the refcounted, lazy-LRU eviction you saw in block_pool.py. “Refcounted” means each entry tracks how many live requests point at it (the set[str] of request ids in cached); “lazy” means an unused entry is not freed at the moment its count hits zero but only later, when the space is actually needed. So an entry whose reference set is non-empty is pinned. When the last request lets go, the entry is not deleted; it slides into freeable, an OrderedDict acting as an LRU (least-recently-used) queue, and stays resident in case a future request asks for the same media. Only under memory pressure is it actually evicted, and then in the order entries entered freeable (popitem(last=False) takes the entry that has sat unreferenced the longest):

while num_embeds > self.num_free_slots:
    mm_hash, num_free_embeds = self.freeable.popitem(last=False)
    del self.cached[mm_hash]
    self.freed.append(mm_hash)
    self.num_free_slots += num_free_embeds

Source: vllm/v1/core/encoder_cache_manager.py

This is encoder-output prefix caching. The check_and_update_cache method is the cache lookup: if the media hash is already present, the request simply joins the reference set and the encoder is never run for that item. If the entry had been sitting in freeable, it gets rescued back into the live pool, exactly the touch-refcount-rescue move from Chapter 7. Evictions reach the worker through get_freed_mm_hashes(), which returns the hashes evicted since the last call; the scheduler passes them to the model runner in SchedulerOutput so it can drop those embeddings from its store, the same scheduler-tells-worker contract that carries everything else.

One subtlety worth internalizing: the EncoderCacheManager docstring is explicit that the cache counts embeddings, not the placeholder tokens around them. Break tokens and text tokens interleaved between images do not consume encoder-cache slots. The two budgets really do measure different things.
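To make the lifecycle concrete, here is a toy version of the manager. The field names (cached, freeable, freed, num_free_slots) and the eviction loop follow the excerpts above; the num_freeable_slots counter, the sizes bookkeeping, and all method bodies are this sketch's assumptions, so treat it as a model of the behavior rather than a copy of encoder_cache_manager.py.

```python
from collections import OrderedDict


class ToyEncoderCache:
    """Refcounted encoder-output cache with lazy LRU eviction (sketch)."""

    def __init__(self, cache_size: int) -> None:
        self.num_free_slots = cache_size      # slots held by no entry at all
        self.num_freeable_slots = cache_size  # free slots plus slots of idle entries
        self.cached: dict[str, set[str]] = {}  # hash -> ids of referencing requests
        self.freeable: OrderedDict[str, int] = OrderedDict()  # idle entries, oldest first
        self.sizes: dict[str, int] = {}  # hash -> embeddings (sketch bookkeeping)
        self.freed: list[str] = []  # evicted hashes the worker must drop

    def check_and_update_cache(self, req_id: str, mm_hash: str) -> bool:
        """Cache lookup: on a hit, join the reference set (rescuing idle entries)."""
        if mm_hash not in self.cached:
            return False
        if not self.cached[mm_hash]:  # idle entry: pull it back out of freeable
            self.num_freeable_slots -= self.freeable.pop(mm_hash)
        self.cached[mm_hash].add(req_id)
        return True

    def can_allocate(self, num_embeds: int, compute_budget: int,
                     already_scheduled: int = 0) -> bool:
        if num_embeds > compute_budget:
            return False  # not enough encoder compute left this step
        needed = num_embeds + already_scheduled
        if needed > self.num_freeable_slots:
            return False  # evicting every idle entry would still not make room
        while needed > self.num_free_slots:  # lazy eviction, oldest idle entry first
            mm_hash, size = self.freeable.popitem(last=False)
            del self.cached[mm_hash]
            del self.sizes[mm_hash]
            self.freed.append(mm_hash)
            self.num_free_slots += size
        return True

    def allocate(self, req_id: str, mm_hash: str, num_embeds: int) -> None:
        self.cached.setdefault(mm_hash, set()).add(req_id)
        self.sizes[mm_hash] = num_embeds
        self.num_free_slots -= num_embeds
        self.num_freeable_slots -= num_embeds

    def free_encoder_input(self, req_id: str, mm_hash: str) -> None:
        refs = self.cached[mm_hash]
        refs.discard(req_id)
        if not refs:  # last reference gone: stay resident, but become evictable
            self.freeable[mm_hash] = self.sizes[mm_hash]
            self.num_freeable_slots += self.sizes[mm_hash]

    def get_freed_mm_hashes(self) -> list[str]:
        freed, self.freed = self.freed, []
        return freed
```

Releasing the last reference never reclaims memory by itself. Only can_allocate, when a new item needs room, turns freeable entries into freed ones, which get_freed_mm_hashes() then reports exactly once.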

The second budget, and how it competes

The encoder's two limits, its per-step compute budget and its total cache size, are computed once at startup in vllm/v1/core/encoder_cache_manager.py:

encoder_compute_budget = max(
    scheduler_config.max_num_encoder_input_tokens, max_tokens_per_mm_item
)
encoder_cache_size = max(
    scheduler_config.encoder_cache_size, max_tokens_per_mm_item
)

Both scheduler_config fields default to max_num_batched_tokens (the very token budget from Chapter 5), as vllm/config/scheduler.py shows, and the max() calls above add a floor: each must cover at least one full media item (max_tokens_per_mm_item), or the largest image could never be scheduled at all. So out of the box the engine carves out an encoder compute allowance roughly the size of the decoder token budget, plus an encoder cache of comparable size. These are not yet user-tunable knobs (max_num_encoder_input_tokens and encoder_cache_size are init=False), which is a real limitation we will come back to.
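The derivation reduces to two max() calls. This sketch assumes, as the paragraph says, that both config fields start at max_num_batched_tokens; the function name encoder_budgets and the numbers in the example are illustrative, not vLLM defaults.

```python
def encoder_budgets(max_num_batched_tokens: int,
                    max_tokens_per_mm_item: int) -> tuple[int, int]:
    """Return (encoder compute budget, encoder cache size) in embeddings."""
    # Assumption from the text: both config fields default to the token budget.
    max_num_encoder_input_tokens = max_num_batched_tokens
    encoder_cache_size = max_num_batched_tokens
    # The floor: each must hold at least the largest single media item.
    compute_budget = max(max_num_encoder_input_tokens, max_tokens_per_mm_item)
    cache_size = max(encoder_cache_size, max_tokens_per_mm_item)
    return compute_budget, cache_size
```

With a token budget of 2048 and an item that can expand to 16384 embeddings, the floor lifts both limits to 16384, so that item is always schedulable in principle. With small items, both limits simply track the token budget.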

The competition happens inside the running and waiting passes of schedule(). Each step starts the encoder budget fresh:

encoder_compute_budget = self.max_num_encoder_input_tokens

Source: vllm/v1/core/sched/scheduler.py

and for every request with media, before committing decoder tokens, the scheduler calls _try_schedule_encoder_inputs. That function is where the two budgets collide. Its contract, from vllm/v1/core/sched/scheduler.py, is that an encoder input is scheduled only if its embeddings overlap the token range about to be computed this step, it is not already cached, there is encoder compute budget left, and the encoder cache has room. The decisive moment is the rollback when an encoder input does not fit:

if not self.encoder_cache_manager.can_allocate(
    request, i, encoder_compute_budget, num_embeds_to_schedule
):
    if num_computed_tokens + shift_computed_tokens < start_pos:
        # We only schedule the decoder tokens just before the
        # encoder input.
        num_new_tokens = start_pos - (
            num_computed_tokens + shift_computed_tokens
        )
    else:
        num_new_tokens = 0
    break

Read that carefully, because it is the whole thesis in code. When the encoder budget or cache is exhausted, the scheduler does not skip the image and barrel ahead. It shrinks the decoder’s num_new_tokens down to the boundary just before the image, then breaks. The reason it must stop there rather than skip past the image is causal: the decoder cannot prefill the placeholder tokens until their embeddings exist, and the embeddings will not exist until the encoder runs, which it cannot do without budget. So the decoder is forced to halt at the wall. The two budgets are thereby coupled: running out of encoder capacity directly throttles how many decoder tokens this request gets this step. A request can stall mid-prompt, having prefilled the text up to the picture, waiting for a future step where encoder room opens up. This is the same chunked-prefill machinery from Chapter 6, now answering to a second resource constraint, and can_allocate itself will trigger eviction from freeable to try to make room before giving up.

The diagram below traces this decision for one image inside one scheduling step. The path that matters is the “no” branch out of can_allocate: instead of failing the request or jumping over the image, it clamps the decoder’s token grant to the boundary just before the placeholder run and breaks out of the loop, leaving the rest for a later step.

flowchart TD
    START["scheduling step: request with media"] --> OVL{"image embeddings overlap the token range to compute now?"}
    OVL -->|no| SKIP["leave it; nothing to schedule for this image yet"]
    OVL -->|yes| HIT{"media hash already in cache?"}
    HIT -->|yes| JOIN["join reference set; encoder is NOT run"]
    HIT -->|no| ALLOC{"can_allocate: encoder budget AND cache room? (evicts from freeable to try)"}
    ALLOC -->|yes| DEBIT["debit budget per embedding; queue encoder run + cache slot"]
    ALLOC -->|no| CLAMP["shrink decoder num_new_tokens to boundary before image, then break"]
    CLAMP --> STALL["request stalls mid-prompt; retries next step"]

When an encoder input does fit, the budget is debited per embedding and the item is queued for both compute and cache allocation:

num_embeds_to_schedule += num_encoder_embeds
encoder_compute_budget -= num_encoder_embeds
mm_hashes_to_schedule.add(item_identifier)
encoder_inputs_to_schedule.append(i)

Source: vllm/v1/core/sched/scheduler.py

The allocation is then committed back in the main loop, where self.encoder_cache_manager.allocate(request, i) reserves the slots and encoder_compute_budget = new_encoder_compute_budget carries the debit forward to the next request in the same step. Note mm_hashes_to_schedule: it dedupes within a single step, so a prompt that repeats the same image twice pays the encoder once.
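The decision tree above, including the per-embedding debit and the within-step dedupe, can be condensed into one function. This is a simplified sketch under stated assumptions: MMItem and schedule_encoder_inputs are hypothetical, one placeholder token maps to one encoder embedding, shift_computed_tokens is taken to be zero, the encoder cache is a plain set of identifiers, and free_slots stands in for can_allocate (no eviction).

```python
from dataclasses import dataclass


@dataclass
class MMItem:
    identifier: str  # media hash / encoder-cache key
    start_pos: int   # first placeholder token of the item in the prompt
    num_embeds: int  # placeholder length; assumed equal to encoder embeddings here


def schedule_encoder_inputs(num_computed: int, num_new: int, items: list[MMItem],
                            cached: set[str], budget: int, free_slots: int
                            ) -> tuple[int, list[int], int]:
    """Return (decoder tokens granted, indices of items to encode, budget left)."""
    to_encode: list[int] = []
    scheduled_hashes: set[str] = set()  # dedupe within this step
    embeds_scheduled = 0
    for i, item in enumerate(items):  # items are sorted by start_pos
        end_pos = item.start_pos + item.num_embeds
        if item.start_pos >= num_computed + num_new or end_pos <= num_computed:
            continue  # item does not overlap the tokens computed this step
        if item.identifier in cached or item.identifier in scheduled_hashes:
            continue  # cache hit or repeat within this step: no encoder run
        if (item.num_embeds > budget
                or embeds_scheduled + item.num_embeds > free_slots):
            # Does not fit: stop the decoder just before this item.
            num_new = max(0, item.start_pos - num_computed)
            break
        embeds_scheduled += item.num_embeds
        budget -= item.num_embeds
        scheduled_hashes.add(item.identifier)
        to_encode.append(i)
    return num_new, to_encode, budget
```

For images A at position 100 and B at position 200 (50 embeddings each) with an encoder budget of 60, A is encoded and the decoder grant shrinks from 300 tokens to 200, stopping exactly at B's first placeholder. If the second image were A again, the dedupe would let all 300 tokens through while encoding A only once.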

There is also a release path. The encoder output is needed only until it has been consumed into the decoder’s KV cache. Once the prefill has moved past the image, _free_encoder_inputs lets the reference go:

elif start_pos + num_tokens <= request.num_computed_tokens:
    # The encoder output is already processed and stored
    # in the decoder's KV cache.
    self.encoder_cache_manager.free_encoder_input(request, input_id)

Source: vllm/v1/core/sched/scheduler.py

After this point the embeddings live on only in freeable, available for cross-request reuse but no longer pinned. The encoder cache is therefore busiest during prefill and quiet during decode, the opposite of the KV cache’s lifetime, which is part of why it gets a separate budget rather than being folded into the token budget.
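The release condition is a single comparison per item. A minimal sketch, where encoder_inputs_to_free is a hypothetical name and positions mirrors mm_position as (start, length) pairs:

```python
def encoder_inputs_to_free(num_computed_tokens: int,
                           positions: list[tuple[int, int]]) -> list[int]:
    """Indices of media items whose placeholder run is fully in the KV cache.

    positions[i] is (start_pos, num_tokens) of item i's placeholder run.
    """
    return [i for i, (start_pos, num_tokens) in enumerate(positions)
            if start_pos + num_tokens <= num_computed_tokens]
```

Once prefill passes the last placeholder run, every item qualifies, which is why nothing stays pinned during decode. The real scheduler frees each item once; this predicate only reports which items are eligible.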

The state diagram below pulls the whole life of one encoder-cache entry together, because the three structures introduced piecemeal above (cached, freeable, freed) are really three phases of one lifecycle. An entry is pinned while any request references it. It moves to freeable, an LRU queue, when the last reference is released, where it stays resident for cross-request reuse. A future request asking for the same media rescues it straight back to pinned. Only genuine memory pressure evicts it, oldest release first, into freed, which is the list the scheduler ships to the worker so the model runner drops the embeddings from its store.

stateDiagram-v2
    [*] --> Pinned: allocate, encoder runs or cache hit
    Pinned --> Pinned: another request references same hash
    Pinned --> Freeable: last reference freed, KV has consumed it
    Freeable --> Pinned: new request asks for same media, rescue
    Freeable --> Freed: memory pressure, evict oldest LRU
    Freed --> [*]: worker drops embeddings

Variable resolution, chunked encoders, and pruning

The encoder budget would be a minor accounting trick if every image cost the same. It does not. Qwen2-VL (arXiv:2409.12191) makes resolution dynamic: a high-resolution image or a long video can expand into thousands of embeddings, while a thumbnail is cheap. The encoder budget exists precisely because the largest item can dwarf a single step’s capacity, which is why compute_mm_encoder_budget floors the budget at max_tokens_per_mm_item and why disable_chunked_mm_input exists as an escape hatch. When chunking is allowed, the scheduler may split a large item’s placeholder run across steps (the encoder still runs on the whole item once, and the decoder consumes its output piece by piece); when it is disabled, the rollback in _try_schedule_encoder_inputs refuses to partially schedule a media item and instead defers the whole thing to a step that can hold it:

if (
    self.scheduler_config.disable_chunked_mm_input
    and num_computed_tokens < start_pos
    and (num_computed_tokens + num_new_tokens)
    < (start_pos + num_encoder_tokens)
):
    num_new_tokens = max(
        0, start_pos - (num_computed_tokens + shift_computed_tokens)
    )
    break

Source: vllm/v1/core/sched/scheduler.py

The reason you would ever disable chunking is in the comment one screen up in the source: encoders typically use bidirectional attention, so the whole item often wants to be processed together. That bidirectionality is what makes the encoder different from the causal decoder, and it is why the encoder budget is denominated in whole embeddings rather than a streaming token count.
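The disable_chunked_mm_input rule can be read as a pure function of four numbers. A sketch, assuming shift_computed_tokens is zero and adding the overlap precondition that the surrounding loop checks before this branch runs; clamp_unchunked is a hypothetical name.

```python
def clamp_unchunked(num_computed: int, num_new: int,
                    start_pos: int, num_encoder_tokens: int) -> int:
    """Decoder tokens granted when a media item may not be split across steps."""
    end = num_computed + num_new
    overlaps = start_pos < end and start_pos + num_encoder_tokens > num_computed
    starts_before = num_computed < start_pos
    ends_inside = end < start_pos + num_encoder_tokens
    if overlaps and starts_before and ends_inside:
        # Stop just before the item; the whole item goes in a later step.
        return max(0, start_pos - num_computed)
    return num_new
```

For 4 text tokens followed by 10 image tokens and a 9-token grant, the result is 4: this step prefills only the text, and the next step can take all 10 image tokens together instead of splitting them 5 and 5.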

Video pushes the problem further, and vLLM answers with pruning rather than pure budgeting. vllm/multimodal/evs.py (labeled “EVS” in the source) implements similarity-based token dropping: it measures the cosine similarity between adjacent video frame embeddings, keeps the first frame whole, and discards the most redundant tokens from the rest. Let $q$ be the pruning ratio, so a fraction $1 - q$ of tokens survive. With $t$ tokens per frame over $f$ frames, the retention count is the simple part, a floored fraction:

$$\text{kept} = \max\!\left(t,\ \lfloor t \cdot f \cdot (1 - q) \rfloor\right)$$

total_tokens = tokens_per_frame * num_frames
evs_num_tokens = int(total_tokens * (1 - q))
min_num_tokens = tokens_per_frame
return max(min_num_tokens, evs_num_tokens)

This is a different lever on the same constraint: instead of finding budget for every embedding a long video would produce, drop the embeddings that carry no new information before they ever reach the cache. It directly shrinks both the encoder cache footprint and the number of placeholder tokens the decoder must process. The curve below plots that formula for a 16-frame clip at 256 tokens per frame: as the pruning ratio $q$ climbs, the retained token count (and with it the encoder-cache footprint) falls off linearly until it hits the one-frame floor of $t$ tokens, below which it cannot drop no matter how aggressive the pruning.

Illustrative: shape follows the chapter’s $\max(t,\ \lfloor t \cdot f \cdot (1-q) \rfloor)$ formula for chosen $t$ and $f$; real per-frame token counts and useful pruning ratios depend on the model and the video’s redundancy.
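For concrete numbers on that curve and a picture of the pruning itself, here is a stdlib-only sketch. retained_tokens reproduces the chapter's formula; prune_video_tokens is a hypothetical reconstruction of the idea described above (compare each token with the same spatial position in the previous frame, keep frame 0 whole, keep the least redundant of the rest), not a port of evs.py.

```python
import math
from typing import Sequence


def retained_tokens(tokens_per_frame: int, num_frames: int, q: float) -> int:
    # The chapter's formula: max(t, floor(t * f * (1 - q))).
    total_tokens = tokens_per_frame * num_frames
    return max(tokens_per_frame, int(total_tokens * (1 - q)))


def _cosine(a: Sequence[float], b: Sequence[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def prune_video_tokens(frames: Sequence[Sequence[Sequence[float]]],
                       q: float) -> list[tuple[int, int]]:
    """Return the (frame, position) pairs that survive pruning.

    frames[f][p] is the embedding of spatial token p in frame f.
    """
    t, f = len(frames[0]), len(frames)
    keep = retained_tokens(t, f, q)
    kept = [(0, p) for p in range(t)]  # the first frame is always kept whole
    # Score every later token by similarity to the same position one frame earlier.
    scored = sorted(
        (_cosine(frames[fi][p], frames[fi - 1][p]), fi, p)
        for fi in range(1, f) for p in range(t)
    )
    # The least similar (least redundant) tokens fill the remaining budget.
    kept += [(fi, p) for _, fi, p in scored[: keep - t]]
    return sorted(kept)
```

At t = 256 and f = 16 the clip has 4096 tokens: q = 0.5 keeps 2048, q = 0.75 keeps 1024, and the one-frame floor of 256 is reached at q = 15/16 = 0.9375; any larger q still returns 256.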

What is still rough

The honest assessment is that this subsystem is younger than the KV path. Three open edges stand out.

First, the budgets are not configurable. Both max_num_encoder_input_tokens and encoder_cache_size are derived from max_num_batched_tokens and carry TODO comments asking to expose them. An operator who knows their traffic is image-heavy cannot yet hand more memory to the encoder cache without inflating the decoder token budget too, which couples two things that should be tunable apart. This is exactly the kind of contention Chapter 21 will teach you to read off the metrics.

Second, the encoder cache is per-replica and ephemeral. There is an ec_connector path in the scheduler (the has_cache_item / external_load_encoder_input branches) that mirrors the KV-connector abstraction from Chapter 16, hinting that encoder outputs could be offloaded or shared across replicas the way KV blocks are. But cross-replica encoder-output sharing is far less mature than cross-replica prefix caching, and a fleet-level router (Chapter 18) has no clean signal yet for “which replica already encoded this image.”

Third, the two-cache consistency rests entirely on that one extra_keys line. The media hash has to be a faithful function of the bytes; the EXIF-ImageID shortcut in the hasher, which trusts a UUID embedded in the image rather than rehashing pixels, is a small reminder that content addressing is only as sound as the content you choose to address. Get the hash wrong and you either lose reuse or, worse, serve one user’s image embeddings keyed under another’s prompt.

The mechanism is sound and the through-line is clean: a second model means a second budget and a second cache, both arbitrated inside the same step loop, both keyed by a media hash that also keeps the decoder’s prefix cache honest. With the encoder cache in place, every input to the language model, text or pixels, is now a sequence of embeddings the engine knows how to schedule and reuse. The next chapter drops below the scheduler entirely, into the attention kernels that have to read all of those embeddings, paged and ragged, at the bandwidth roofline.

Further reading

  • LLaVA: Visual Instruction Tuning — arXiv:2304.08485 — establishes the vision-encoder-then-LLM template that defines the serving shape of a multimodal request.
  • Qwen2-VL: Enhancing Vision-Language Model’s Perception of the World at Any Resolution — arXiv:2409.12191 — variable-resolution inputs whose wildly different per-item costs are exactly what the encoder budget and chunked encoder scheduling exist to manage.

Attention kernels: from PagedAttention to FlashAttention, FlashInfer, and MLA

In Chapter 4 we made the KV cache paged. We chopped each sequence’s keys and values into fixed 16-token blocks, scattered those blocks anywhere in a global pool, and kept a per-request block table mapping logical positions to physical blocks. That bought us the concurrency that makes a replica economical. But it also quietly broke something. A textbook attention kernel wants its keys and values laid out contiguously: one flat [seq_len, num_heads, head_size] tensor it can stream linearly. Paging hands the kernel the opposite of that — a sequence whose KV lives in blocks that are physically discontiguous and, across a batch, of wildly different lengths. The layout is only half the design. The other half is a kernel that can attend over ragged, block-scattered sequences and still hit the memory-bandwidth roofline we established in Chapter 3.

This chapter is about that kernel, and about where it lives in vLLM. The punchline up front: there is no single attention kernel. There is a dispatch seam — an opaque custom op — and behind it a zoo of backends, each tuned for a hardware generation, a precision, or a model architecture. The most consequential of those architectures is Multi-head Latent Attention, which doesn’t just need a different kernel. It changes the KV cache you sized in Chapter 4 into a single compressed latent vector per token, and it’s how the flagship 2026 open models actually run.

Before diving in, it helps to fix the vocabulary the rest of the chapter leans on. Attention is the operation where each query token computes a weighted sum over the keys and values of every token it can see; the weights come from a softmax over query-key dot products. Prefill is the phase that processes all the prompt tokens at once (many queries), and decode is the phase that generates one token at a time (one query attending over a long history). A kernel is a single GPU program launch; the cost we keep returning to is whether that launch is bandwidth-bound (limited by how fast it can read the KV cache out of memory) or launch-bound (limited by the fixed CPU and scheduling overhead of starting the kernel at all). The whole chapter is the story of how vLLM keeps the attention kernel bandwidth-bound while feeding it the awkward, paged, ragged input that real serving produces.

The dispatch seam

Start with the indirection, because it explains everything downstream. A model’s attention layer never calls FlashAttention, or FlashInfer, or any kernel by name. It calls a registered PyTorch custom op — a function that PyTorch knows about by name but whose insides are sealed off from the compiler. Here is the op, in vllm/model_executor/layers/attention/attention.py:

def unified_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: LayerNameType,
    ...
) -> None:
    ...
    layer_name = _resolve_layer_name(layer_name)
    attn_metadata, self, kv_cache, _ = get_attention_context(layer_name)

    self.impl.forward(
        self, query, key, value, kv_cache, attn_metadata, output=output, ...)

Two things matter here. First, the op is registered with direct_register_custom_op (same file), which means torch.compile treats it as a black box — it will not trace into it. That opacity is exactly what lets the compiler split the model graph at attention boundaries, which is the mechanism Chapter 10 uses to capture everything between attentions as CUDA graphs while leaving attention itself eager. The dispatch seam and the cudagraph story are the same seam. Second, self.impl is a concrete backend object: it is the actual kernel wrapper, resolved exactly once, at startup, by get_attn_backend in vllm/v1/attention/selector.py. Think of self.impl.forward as a function pointer that was bound long before the first token ever arrived, so the per-step hot path pays no selection cost — it just calls through.

How does that binding get made? get_attn_backend builds an AttentionSelectorConfig describing the shape of the problem — the head size, the dtype, the KV-cache dtype — together with a pile of feature flags that capture which kind of attention this is: use_mla (is this a latent-attention model?), has_sink (does it use attention sinks?), use_sparse (top-k sparse attention?). It hands that config to the platform layer, which knows the GPU generation and returns a class path — a string naming the backend to load. The names live in an enum in vllm/v1/attention/backends/registry.py:

    FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
    ...
    FLASHINFER = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
    ...
    TRITON_MLA = "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
    CUTLASS_MLA = "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend"
    FLASHMLA = "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend"

That is the whole map. A name, a class path, lazily imported. The selection is cached (@cache on _cached_get_attn_backend), and the chosen backend can even demand a specific physical KV layout — get_required_kv_cache_layout() may call set_kv_cache_layout, so the kernel reaches back and dictates how Chapter 4’s blocks are arranged in memory. The kernel is not a passive consumer of the cache. It co-designs it.

The diagram below traces the two timescales of this seam. The left column is what happens once, at startup: the selector reads the config, picks a backend, and binds it into self.impl. The right column is what happens on every forward pass: the model calls the custom op, which looks up the already-bound backend and dispatches into it. Notice that the only edge crossing from startup into the hot path is the binding of self.impl — everything expensive about selection has been hoisted out of the per-step loop.

flowchart TD
    subgraph startup["once, at startup"]
        cfg["AttentionSelectorConfig (head size, dtype, kv dtype, use_mla, has_sink, use_sparse)"] --> sel["get_attn_backend in selector.py"]
        sel --> plat["platform layer picks a class path from registry.py"]
        plat --> impl["self.impl = concrete backend object"]
    end
    subgraph hot["every forward pass (eager)"]
        layer["model attention layer"] --> op["unified_attention_with_output (custom op, opaque to torch.compile)"]
        op --> ctx["get_attention_context: look up attn_metadata + kv_cache"]
        ctx --> call["self.impl.forward(...)"]
    end
    impl -. binds .-> call
    call --> backend["chosen kernel: FlashAttention / FlashInfer / one of the MLA backends"]

FlashAttention: where “paged” actually lives now

The default backend on NVIDIA is FlashAttention. Its job each step is two-phase: first write this step’s freshly-computed keys and values into the paged cache so they are visible to future tokens, then attend over the whole cache. Why two phases and not one? Because the keys and values a token produces this step must be stored before that same step’s query reads them, and every later step will reread them too. The write is what makes the cache grow by one token’s worth of state per step; the attend is what consumes the whole accumulated cache. In current vLLM these are two separate custom ops — the write was pulled out into its own do_kv_cache_update method. Since torch.compile cannot see inside either op, it has no way to know that the attend depends on the write having happened first, so vLLM threads a dummy tensor argument (kv_cache_dummy_dep) from the write into the attend purely to manufacture that data dependency and pin the ordering. The work is the same as a single fused phase; the split just makes the ordering explicit to the compiler.

The write is a scatter, not a copy. A copy would lay the new tokens down contiguously; a scatter sends each token to a different, possibly far-apart, physical slot, because paging means consecutive logical positions can live in any physical block. Here it is, in vllm/v1/attention/backends/flash_attn.py:

        reshape_and_cache_flash(
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping,
            self.kv_cache_dtype,
            layer._k_scale,
            layer._v_scale,
        )

slot_mapping is the per-token list of physical slots produced by the block table from Chapter 4 — for each token in this batch, the single integer address where its KV entry belongs in the global pool. The op reshapes the new K/V and writes each token to its slot, fusing in the FP8 KV-cache scaling we’ll meet in Chapter 12 (the _k_scale and _v_scale arguments). Note the comment in the source: it deliberately does not slice key[:num_actual_tokens], because the op infers the token count from slot_mapping’s shape — a micro-optimization that tells you how cost-sensitive this path is.
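The slot arithmetic and the scatter can be sketched on a flat Python list. This assumes physical block b owns slots b * block_size through (b + 1) * block_size - 1 of one pool; build_slot_mapping and scatter_kv are hypothetical names, and the real op writes K and V tensors on the GPU with the FP8 scaling fused in.

```python
def build_slot_mapping(block_table: list[int], positions: list[int],
                       block_size: int) -> list[int]:
    # Logical position -> (logical block, offset) -> physical slot in the flat pool.
    return [block_table[pos // block_size] * block_size + pos % block_size
            for pos in positions]


def scatter_kv(cache: list, new_entries: list, slot_mapping: list[int]) -> None:
    # The number of tokens written is taken from slot_mapping, so any padding
    # rows at the end of new_entries are simply never touched.
    for i, slot in enumerate(slot_mapping):
        cache[slot] = new_entries[i]
```

With 4-token blocks and a block table of [7, 2], logical positions 3, 4, and 5 land in slots 31, 8, and 9: neighbors in the sequence, far apart in memory.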

Then the attention itself:

                flash_attn_varlen_func(
                    q=query[:num_actual_tokens],
                    k=key_cache,
                    v=value_cache,
                    out=output[:num_actual_tokens],
                    cu_seqlens_q=cu_seqlens_q,
                    max_seqlen_q=max_seqlen_q,
                    seqused_k=seqused_k,
                    ...
                    block_table=block_table,
                    ...
                )

Source: vllm/v1/attention/backends/flash_attn.py

This single call is where PagedAttention lives in 2026. There is no separate “paged attention kernel” on the hot path anymore. flash_attn_varlen_func is FlashAttention’s variable-length kernel, and two of its arguments do the heavy lifting. The first is cu_seqlens_q, the cumulative sequence lengths of the queries — this is what “varlen” means. Imagine three requests in a batch with query lengths 7, 1, and 2000. Rather than pad them all to 2000 and waste the kernel’s time on padding, vLLM concatenates them end to end into one flat buffer of length 2008 and hands the kernel the running offsets [0, 7, 8, 2008]. The kernel reads those offsets to know where each sequence starts and stops, so a length-7 sequence and a length-2000 sequence ride the same launch with zero padding waste. The second is block_table, the per-request mapping from logical block to physical block from Chapter 4; passing it in is what makes the kernel walk the scattered cache instead of a contiguous tensor. Paging became a feature of the FlashAttention kernel rather than a kernel of its own. This is the practical payoff of the FlashAttention-2 work (arXiv:2307.08691): by reorganizing the attention computation to maximize work per memory pass and parallelize across sequence length and the batch, it keeps the kernel bandwidth-bound rather than launch-bound, which is precisely what a continuous-batching engine that re-plans every step needs.

The varlen formulation also explains why the scheduler from Chapter 5 could collapse prefill and decode into one batch: a decode step is just a varlen sequence with query length 1, and a prefill chunk is one with query length N. The kernel does not care which is which. It reads cu_seqlens_q and goes.
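Here is a minimal sketch of the varlen bookkeeping, using the 7, 1, 2000 example. The running offsets are a prefix sum over query lengths, and each sequence is recovered by slicing between consecutive offsets. The helper names are hypothetical; the real kernel receives these offsets as an integer tensor rather than a Python list.

```python
from itertools import accumulate


def cu_seqlens(query_lens):
    """Cumulative offsets into the packed query buffer, starting at 0."""
    return [0, *accumulate(query_lens)]


def unpack(packed, offsets):
    """Recover each sequence's tokens by slicing between consecutive offsets."""
    return [packed[start:end] for start, end in zip(offsets, offsets[1:])]


query_lens = [7, 1, 2000]  # a short prefill chunk, a decode step, a long prefill
offsets = cu_seqlens(query_lens)

# Pack end to end, tagging every token with its request id so the slices are checkable.
packed = [req for req, n in enumerate(query_lens) for _ in range(n)]

padded_slots = len(query_lens) * max(query_lens)  # cost of padding everything to 2000
packed_slots = offsets[-1]                        # cost of packing: no padding at all
```

The decode step appears as a one-token slice between offsets 7 and 8. Nothing in the bookkeeping distinguishes it from a prefill.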

The diagram below traces a single FlashAttention step end to end. Follow the new query/key/value tensors in from the top: the keys and values branch off to be scattered into the paged cache via slot_mapping, and then the attend phase reads the query plus the entire updated cache — packed by cu_seqlens_q, walked through block_table — to produce the output. The cache write must finish before the attend reads, which is the ordering the two-op split exists to guarantee.

flowchart TD
    qkv["new q, k, v for this step's tokens"] --> q["query buffer"]
    qkv --> kv["key, value buffers"]
    kv --> scatter["reshape_and_cache_flash: scatter each token to its slot"]
    slot["slot_mapping: physical slot per token"] --> scatter
    scatter --> cache["paged KV cache (updated)"]
    q --> attend["flash_attn_varlen_func"]
    cache --> attend
    cu["cu_seqlens_q: pack ragged batch, no padding"] --> attend
    bt["block_table: logical to physical blocks"] --> attend
    attend --> out["attention output"]

One warning in this file is worth internalizing, because it reframes how you read every backend forward in vLLM:

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.

Source: vllm/v1/attention/backends/flash_attn.py

Because attention runs eager in the gaps between captured graphs (Chapter 10), the Python in this method is on the critical path of every decode step. A stray .view() that touches no GPU memory can still cost you tail latency at high QPS. This is the kind of thing a distributed-systems engineer recognizes instantly: it is the per-request CPU overhead problem, just relocated into a kernel wrapper.

The original PagedAttention CUDA kernels still exist — csrc/libtorch_stable/attention/paged_attention_v1.cu and paged_attention_v2.cu, both Apache-licensed and carrying their FasterTransformer lineage in the header — but they are off the main path now, kept as a reference implementation and a fallback. Reading paged_attention_v2.cu is still worthwhile for your performance instincts: v2 splits each sequence’s KV across thread blocks and does a second reduction pass, the classic split-K trick for keeping the GPU busy when a single long sequence would otherwise starve it. That is the same problem FlashAttention-2 solves more generally.

FlashInfer and cascade attention

FlashInfer is the alternative NVIDIA backend, and the reason to reach for it is in vllm/v1/attention/backends/flashinfer.py. It carries a special metadata field:

    use_cascade: bool
    """
    If True, the entire batch is a cascade attention call, and the
    `prefill` and `decode` fields will both be None.
    """

Cascade attention is the kernel-level partner of the prefix caching from Chapter 7. When many requests in a batch share a long common prefix — a system prompt, a few-shot block — naive attention reads that shared KV once per request. If sixteen requests all share a 1000-token system prompt, that prompt’s keys and values get streamed out of memory sixteen times per step, and since decode is bandwidth-bound (Chapter 3) that rereading is the cost.

The curve below makes the divergence concrete for that 1000-token shared prefix: naive attention rereads the prefix once per request, so its per-step prefix traffic grows linearly with batch size, while cascade reads the prefix exactly once no matter how wide the batch fans out. At a batch of sixteen that is the difference between 16000 and 1000 prefix-token reads per step.

Illustrative: counts follow exactly from the chapter’s 1000-token-prefix example (naive = prefix x batch, cascade = prefix); a real batch’s prefix length and fan-out vary.

Cascade fixes this by exploiting a property of softmax attention: an attention result over a long context can be computed in pieces and stitched back together exactly, as long as you carry along each piece’s softmax normalization statistics (its running max and sum) so the merge reweights the pieces correctly. This is the same online-softmax trick that makes FlashAttention itself tile-able. Cascade uses it to split each request’s context into a piece shared by the whole batch and a piece private to the request, and computes the two at separate levels. Level one attends every query in the batch over the shared prefix, reading that prefix’s KV exactly once for the whole batch. Level two attends each request’s queries over only that request’s unique suffix. Then it merges the two partial results with a numerically correct softmax combine. The shared prefix went from sixteen rereads to one.
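The merge is easy to get subtly wrong, so here is a minimal pure-Python sketch of the two-level computation. It illustrates the online-softmax combine and is not FlashInfer's kernel. Each partial result carries an unnormalized output, its max score, and its sum of exponentials, and the merge rescales both sides to a common max before adding. All names are hypothetical. A real kernel loads the shared-prefix KV once and reuses it for every query; the Python loop here only shows the arithmetic.

```python
import math


def attend_partial(q, keys, values):
    """Attention of one query over one chunk of KV, left unnormalized.

    Returns (sum_j exp(s_j - m) * v_j, m, sum_j exp(s_j - m)), where m is the max score.
    """
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    out = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(len(values[0]))]
    return out, m, sum(weights)


def merge(a, b):
    """Online-softmax combine: rescale both partials to the larger max, then add."""
    (out_a, m_a, s_a), (out_b, m_b, s_b) = a, b
    m = max(m_a, m_b)
    ca, cb = math.exp(m_a - m), math.exp(m_b - m)
    return [ca * x + cb * y for x, y in zip(out_a, out_b)], m, ca * s_a + cb * s_b


def finalize(partial):
    """Divide by the softmax denominator to get the attention output."""
    out, _, s = partial
    return [x / s for x in out]


def cascade_attention(queries, prefix_k, prefix_v, suffixes):
    """Level 1: every query over the shared prefix. Level 2: each query over its own suffix."""
    level1 = [attend_partial(q, prefix_k, prefix_v) for q in queries]
    level2 = [attend_partial(q, k, v) for q, (k, v) in zip(queries, suffixes)]
    return [finalize(merge(a, b)) for a, b in zip(level1, level2)]
```

Checking the result against ordinary attention over the concatenated prefix plus suffix is the quickest way to confirm the merge is exact.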

At dispatch the kernel is invoked through a planned wrapper:

        if attn_metadata.use_cascade:
            # Cascade attention (rare case).
            assert attn_metadata.cascade_wrapper is not None
            output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache))
            return output

Source: vllm/v1/attention/backends/flashinfer.py

The diagram below contrasts the naive path with the cascade path for a batch that shares a prefix. The point to read off it is the fan-in on the shared-prefix box: in the naive case every request’s attend reaches into it (three arrows, three rereads), while in the cascade case a single level-one attend reads it once and feeds the merge.

flowchart LR
    subgraph naive["naive: shared prefix reread per request"]
        r1["request 1 query"] --> p1["attend over prefix + suffix 1"]
        r2["request 2 query"] --> p2["attend over prefix + suffix 2"]
        r3["request 3 query"] --> p3["attend over prefix + suffix 3"]
        p1 --> pref["shared prefix KV in memory"]
        p2 --> pref
        p3 --> pref
    end
    subgraph cascade["cascade: shared prefix read once"]
        allq["all queries in batch"] --> shared["level 1: attend over shared prefix once"]
        q1["request 1 query"] --> s1["level 2: attend over suffix 1"]
        q2["request 2 query"] --> s2["level 2: attend over suffix 2"]
        q3["request 3 query"] --> s3["level 2: attend over suffix 3"]
        shared --> merge["softmax-correct merge of partial results"]
        s1 --> merge
        s2 --> merge
        s3 --> merge
    end

The “rare case” comment is honest: cascade only pays off when the shared prefix is long and the batch fan-out is wide, and the metadata builder gates it on a measured common_prefix_len. It is honest in another way too. At the time of writing, the FlashInfer backend’s use_cascade_attention actually returns False with a “doesn’t work, disable it for now” note, so the path is dormant in practice. But the mechanism is the point. When it does fire, it turns prefix sharing from a memory and compute win (Chapter 7 saved you the duplicate blocks and the recompute) into a bandwidth win as well (cascade also saves you the reread). This is the broader thesis of the FlashInfer paper (arXiv:2501.01005): attention serving has many shapes (paged decode, ragged prefill, shared-prefix cascade), and a serving system wants a single customizable engine that compiles the right kernel for each, rather than a pile of hand-written specializations. Read it for the block-sparse formulation that unifies paging and cascading under one abstraction.

MLA: the KV cache becomes a latent vector

Everything so far assumed the KV cache you sized in Chapter 4: two tensors, keys and values, one entry per attention head per token. For a model with many KV heads, that cache is large, and since decode is bandwidth-bound (Chapter 3), every byte of it is reread on every step. The bytes are the bottleneck, so the obvious lever is to store fewer of them.

Multi-head Latent Attention, introduced with DeepSeek-V2 (arXiv:2405.04434), does exactly that, and it attacks the cache itself rather than the kernel around it. The key idea is that the per-head keys and values, across all heads, are highly redundant — they can be reconstructed from a much smaller shared representation. So instead of caching per-head keys and values, MLA caches a single low-rank latent vector per token. “Low-rank” here means the latent lives in a much smaller dimension than the full per-head K and V it stands in for. When attention needs the actual per-head K and V, the kernel reconstructs them on the fly by multiplying the latent through fixed up-projection matrices (matrices learned at training time that expand the compact latent back out to full head dimensions). The trick that makes this cheap at serve time is that those up-projection matrices can be algebraically folded into the query and output projections, so the kernel often never has to materialize the full K and V at all.
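The folding claim is just associativity of matrix products, and a few lines can check it. This is a minimal sketch with hypothetical names and toy sizes. W_uk is an up-projection from a 3-dim latent to a 4-dim key. The score of query q against the reconstructed key W_uk c equals the score of the pre-projected query W_uk^T q against the latent c itself. Real MLA does this per head and keeps the rotary slice separate; the sketch ignores both.

```python
def matvec(m, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def transpose(m):
    """Swap rows and columns of a list-of-rows matrix."""
    return [list(col) for col in zip(*m)]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


# Hypothetical toy sizes: head_dim 4, latent rank 3.
W_uk = [[0.5, -1.0, 0.2],
        [0.1, 0.3, -0.7],
        [1.2, 0.0, 0.4],
        [-0.3, 0.8, 0.6]]        # up-projection: latent (3) -> key (4)
c = [0.9, -0.2, 0.5]             # the cached latent for one token
q = [0.3, -1.1, 0.7, 0.2]        # one query head

# Naive path: materialize the full key, then score it.
score_materialized = dot(q, matvec(W_uk, c))

# Absorbed path: fold W_uk into the query once, then score against the latent directly.
q_absorbed = matvec(transpose(W_uk), q)  # lives in the 3-dim latent space
score_absorbed = dot(q_absorbed, c)
```

The absorbed query is computed once per step. After that, the per-token work touches only the small latent rather than a reconstructed full key, and at real sizes that gap between kv_lora_rank and the full per-head K and V is what pays.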

The cleanest evidence is the cache shape. Recall the FlashAttention layout was [num_blocks, 2, block_size, num_kv_heads, head_size]. The MLA layout, from vllm/model_executor/layers/attention/mla_attention.py:

    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,  # assumed to be 1 for MLA
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        return (num_blocks, block_size, head_size)

There is no 2 (no separate K and V — the latent carries both) and no head dimension (num_kv_heads is 1). And the head_size is not a head size in the usual sense:

        self.head_size = kv_lora_rank + qk_rope_head_dim

Source: vllm/model_executor/layers/attention/mla_attention.py

So a cached MLA token is the latent itself (kv_lora_rank, the compressed part) plus a small slice that carries rotary position information (qk_rope_head_dim), packed into one vector. For DeepSeek-V3 those are 512 and 64, so each cached token is a single $512 + 64 = 576$-element vector. Contrast that with a conventional model caching, say, 8 KV heads of 128 dims each across two tensors: $2 \times 8 \times 128 = 2048$ elements per token. Against this modest 8-head baseline that is already a $2048 / 576 \approx 3.6\times$ shrink per token, and against a wide-headed model it reaches the order of magnitude that makes MLA worth the trouble. This is exactly the forward-pointer Chapter 4 flagged: not every model has a fat growing KV cache. MLA’s is compressed, and because decode rereads the whole cache every step, a smaller cache is a proportionally cheaper reread. That saved bandwidth, along with the freed memory, buys back concurrency and context length.

The curve below traces per-sequence cache size (one attention layer, fp16) against context length for both layouts. Because cache size is linear in context length, MLA’s smaller per-token footprint shows up as a flatter line: at 128k tokens this 8-head baseline holds roughly 524 MB per layer where MLA holds about 147 MB, and the gap only widens with wider conventional heads.

Illustrative: bytes computed from the chapter’s element counts (2048 vs 576 elements/token x 2 bytes fp16) for a single attention layer; the conventional curve scales with that model’s KV-head count.
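The numbers in the curve are simple enough to reproduce. This is a minimal sketch that assumes fp16 (2 bytes per element), a single attention layer, and “128k” meaning 128,000 tokens, which is what the 524 MB and 147 MB figures correspond to. The 128-head wide baseline is hypothetical and is included only to show the larger ratio.

```python
def conventional_elems_per_token(num_kv_heads, head_dim):
    """Separate K and V tensors, one entry per KV head."""
    return 2 * num_kv_heads * head_dim


def mla_elems_per_token(kv_lora_rank, qk_rope_head_dim):
    """One packed vector: compressed latent plus the rotary slice."""
    return kv_lora_rank + qk_rope_head_dim


def cache_bytes(elems_per_token, num_tokens, bytes_per_elem=2):
    """Per-sequence, per-layer cache size in bytes."""
    return elems_per_token * num_tokens * bytes_per_elem


conv = conventional_elems_per_token(num_kv_heads=8, head_dim=128)    # 2048
mla = mla_elems_per_token(kv_lora_rank=512, qk_rope_head_dim=64)     # 576
wide = conventional_elems_per_token(num_kv_heads=128, head_dim=128)  # hypothetical wide model

tokens = 128_000
conv_mb = cache_bytes(conv, tokens) / 1e6  # decimal megabytes
mla_mb = cache_bytes(mla, tokens) / 1e6
```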

The cost is moved into compute, and the design note at the top of mla_attention.py is unusually candid about the tradeoff. MLA has two computation paths:

MLA has two possible ways of computing, a data-movement friendly approach and a
compute friendly approach. We generally want to use the compute friendly
approach for "prefill" ... and the data-movement friendly approach for "decode"

Why split into two paths at all? Because prefill and decode sit on opposite sides of the roofline from Chapter 3, and the latent gives you a knob that lets you choose where the work lands. In prefill you have many query tokens against a comparable number of keys, so there is plenty of arithmetic to amortize a memory pass — you are compute-bound. There it pays to up-project the latents back to full per-head K and V once and run ordinary multi-head attention; this is the compute-friendly path, and “data-movement friendly” would be wasted because data movement is not the bottleneck. In decode you have one query token attending over a long context, so there is very little arithmetic and the cost is dominated by reading the cache — you are bandwidth-bound. There it pays to not up-project: keep everything in the compressed latent space and run what amounts to multi-query attention with one fat head. That spends more FLOPs per token (the up-projection math gets folded into the per-token work) but touches far fewer bytes, which is exactly the right trade when bytes are the bottleneck. The same model thus uses two different kernel strategies depending on which phase the scheduler put a token in.
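The following back-of-the-envelope sketch shows why the two phases land on opposite sides of the roofline. It uses a simplified per-head model in which QK^T and PV each cost two FLOPs per query, per context token, per head dimension, and the kernel reads K and V once. It ignores query and output traffic and compares against an assumed ridge point of 300 FLOP/byte, chosen only for illustration.

```python
RIDGE_FLOPS_PER_BYTE = 300  # assumed hardware ridge point, for illustration only


def attention_intensity(num_queries, context_len, head_dim=128, bytes_per_elem=2):
    """FLOPs per byte of KV read for one head attending over a cached context."""
    flops = 2 * (2 * num_queries * context_len * head_dim)   # QK^T plus PV
    kv_bytes = 2 * context_len * head_dim * bytes_per_elem   # read K and V once
    return flops / kv_bytes


def regime(num_queries, context_len):
    """Classify against the assumed ridge point."""
    if attention_intensity(num_queries, context_len) > RIDGE_FLOPS_PER_BYTE:
        return "compute-bound"
    return "bandwidth-bound"
```

In this model, context length and head dimension cancel out. Intensity is set by how many queries share each byte of KV, which is one in decode and the chunk size in prefill. That is also why the decode path's multi-query formulation helps: every query head shares each cached latent byte.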

The diagram below shows both paths diverging from the one shared latent cache. Read it as a fork: the cache is identical for both phases, but prefill walks left into the up-project-then-attend path while decode walks right into the stay-compressed path.

flowchart TD
    cache["latent KV cache: one 576-dim vector per token (kv_lora_rank + qk_rope_head_dim)"] --> fork{"which phase?"}
    fork -->|"prefill: many queries, compute-bound"| up["up-project latent to full per-head K, V"]
    up --> mha["ordinary multi-head attention (compute-friendly)"]
    fork -->|"decode: one query, bandwidth-bound"| keep["keep latent compressed, fold up-projection into the math"]
    keep --> mqa["multi-query attention with one fat head (data-movement-friendly)"]
    mha --> out["attention output"]
    mqa --> out

That is why the mla/ directory is a family of backends — flashmla.py, cutlass_mla.py, flashinfer_mla.py, triton_mla.py — each pairing a prefill path with a decode path tuned for a specific GPU. The selector’s use_mla flag is what routes a DeepSeek model down this entire parallel universe.

There is one more turn of the screw. The 2026 flagship models push toward sparse MLA. The observation is that even after compressing each token, decode still rereads every token’s latent on every step, and most of those tokens contribute almost nothing to the result. So instead of attending over the whole context, a lightweight indexer scores the cached tokens against the current query and selects only the top-k most relevant ones — say the 2048 that matter out of a 128k context — and attention runs over just those. vLLM has backends for it — flashmla_sparse.py, flashinfer_mla_sparse.py — fed by an indexer.py (which produces the per-request top-k choices) and the Triton helpers in sparse_utils.py that translate those per-request top-k indices into global cache positions the kernel can fetch. The FlashMLA backend even reorders the batch around this: reorder_batch_threshold: int = 128 # process small prefills with decode pathway, a comment that captures how blurry the prefill/decode line has become once the kernel, not the scheduler, owns the distinction. Sparse attention over a compressed cache is a credible answer to long-context decode cost, and it is still actively moving — the right kernel and the right top-k policy are not settled.
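Here is a minimal sketch of the select-then-attend pattern with a hypothetical indexer. Each cached token has a small index key, a cheap dot product with an index query scores it, and ordinary softmax attention runs over only the top-k survivors. The sketch does not model the real indexer, its scoring function, or the mapping from chosen indices to global cache positions.

```python
import heapq
import math


def softmax_attend(q, keys, values):
    """Ordinary softmax attention of one query over the given KV."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]


def topk_sparse_attend(q, q_index, keys, values, index_keys, k):
    """Score every cached token cheaply, keep the top-k, attend over only those."""
    index_scores = [sum(a * b for a, b in zip(q_index, ik)) for ik in index_keys]
    chosen = sorted(heapq.nlargest(k, range(len(index_scores)), key=index_scores.__getitem__))
    out = softmax_attend(q, [keys[i] for i in chosen], [values[i] for i in chosen])
    return chosen, out
```

Dropping tokens is where exactness goes. With k below the context length, the result is an approximation whose quality depends entirely on the indexer. With k equal to the context length, it reduces to dense attention.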

What is solved, what is not

Step back and the shape is clear. Paged layout (Chapter 4) and the attention kernel are not independent. The layout is a contract, and the kernel both honors it and, increasingly, dictates it — varlen indexing for ragged batches, block tables for paging, cascade for shared prefixes, and for MLA a wholesale redefinition of what a cache entry even is. The custom-op seam is what lets vLLM swap among these without the model code knowing, and what lets the compiler in Chapter 10 capture everything around them.

What remains genuinely unsolved is selection and portability. The right backend depends on GPU generation, precision, head dimensions, and model family, and the matrix is sparse and shifting — a kernel that is fastest on Hopper may not exist on Blackwell yet, and MLA’s two-path design means a regression can hide in just the decode kernel of just one backend. The eager-mode CPU-overhead warning is a standing tax that every new backend must pay attention to. And sparse MLA is early: top-k attention trades exactness for bandwidth, and how much that costs in quality at long context is still being measured in public.

With the kernel understood as an opaque, dispatched, eager-in-the-gaps op, Chapter 10 can take the next step: capturing everything between these ops as CUDA graphs to kill the per-step launch overhead that, on a one-token decode, can otherwise outrun the GPU itself.

Further reading

  • FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning — arXiv:2307.08691 — reorganizes attention to maximize work per memory pass and parallelize across sequence length, the basis of the varlen kernel vLLM calls by default.
  • FlashInfer: Efficient and Customizable Attention Engine for LLM Serving — arXiv:2501.01005 — a block-sparse abstraction that unifies paged decode, ragged prefill, and shared-prefix cascade into one compiled engine.
  • DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model — arXiv:2405.04434 — introduces Multi-head Latent Attention, which compresses the KV cache into a single low-rank latent vector per token.

Killing per-step overhead: CUDA graphs and async scheduling

By now the engine in Part II looks formidable. The token-budget scheduler from Chapter 5 re-plans the batch every step; the paged attention backends from Chapter 9 attend over ragged, block-scattered sequences at something close to the bandwidth roofline. Each decode step does a real amount of GPU work. And yet, on a fast accelerator running a modestly sized model, you can profile a steady-state decode loop and find the GPU sitting idle for a meaningful fraction of every step, waiting on the CPU.

This is the embarrassing part of inference serving that nobody warns you about. In decode, every step produces one token per sequence, so the actual matmuls and attention calls are short. But before any of them can run, Python has to walk the model’s forward method, dispatch each operator through PyTorch’s eager-mode machinery, and ask the CUDA driver to launch each kernel. That is dozens of kernels per layer times dozens of layers, each launch costing microseconds of pure host-side overhead, so a single decode step can spend hundreds of microseconds of host time just issuing work. Stack enough and the launch of the work outlasts the work itself: the GPU finishes a tiny GEMM and stalls, idling, while the host queues up the next one.

To make the overhead concrete, picture the host-to-device pipeline for one decode step. The CPU is the producer: it walks Python, dispatches operators, and enqueues kernels onto a CUDA stream, which is just an ordered queue of work the GPU drains. The GPU is the consumer: it pulls kernels off the stream and runs them. Each kernel does microseconds of compute but also costs microseconds of host time to enqueue. When the compute is large (prefill), the GPU is the bottleneck and the host’s enqueueing hides behind it. When the compute is tiny (decode), the roles invert: the host cannot enqueue kernels fast enough to keep the GPU fed, so the GPU drains the stream, finds it empty, and waits. The diagram below traces both regimes.

flowchart LR
    P["CPU producer: walk forward, dispatch ops, enqueue kernels"] -->|"CUDA stream"| Q["GPU consumer: drain stream, run kernels"]
    Q --> R{"kernel size vs enqueue rate"}
    R -->|"prefill: kernels big, GPU is the bottleneck, host hides behind it"| OK["GPU stays busy"]
    R -->|"decode: kernels tiny, host can't keep up, stream runs dry"| IDLE["GPU idles, waiting on host"]

The crossover the diagram describes is easier to feel as a curve. The host time to enqueue a decode step is roughly fixed: it is the same procession of per-kernel launches regardless of how many sequences ride along in the batch. The GPU compute, by contrast, grows with the batch. The curves below plot both against batch size. Where GPU compute sits below the flat host-launch line (small batches, the decode regime), the host is the bottleneck and the GPU idles in the gap between the two; where it climbs above (large batches, toward prefill-like work), the GPU is busy enough to hide the launches behind it.

Illustrative: the flat host line and linear-ish GPU curve have the right shapes and the crossover is real, but the absolute microseconds and the exact crossover batch depend on the model, the accelerator, and the kernel mix.
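The crossover can be located with a toy cost model. Every number below is hypothetical and chosen only to reproduce the shapes in the figure. The host cost is a flat kernels-per-layer times layers times per-launch overhead, and the GPU cost grows linearly with batch size.

```python
import math


def host_launch_us(kernels_per_layer, num_layers, launch_us):
    """Host time to enqueue one step; independent of batch size."""
    return kernels_per_layer * num_layers * launch_us


def gpu_compute_us(batch, base_us, per_seq_us):
    """GPU time for one decode step; grows with the batch."""
    return base_us + per_seq_us * batch


def crossover_batch(host_us, base_us, per_seq_us):
    """Smallest batch at which GPU work covers the host launch time."""
    return max(1, math.ceil((host_us - base_us) / per_seq_us))


def gpu_idle_fraction(batch, host_us, base_us, per_seq_us):
    """Fraction of the step the GPU waits on the host (0 once GPU-bound)."""
    gpu = gpu_compute_us(batch, base_us, per_seq_us)
    return max(0.0, host_us - gpu) / host_us


host = host_launch_us(kernels_per_layer=12, num_layers=32, launch_us=3)  # 1152 us
cross = crossover_batch(host, base_us=200, per_seq_us=10)
```

With these made-up constants a single-sequence decode leaves the GPU idle for over 80% of the step, and the GPU only becomes the bottleneck near a batch of 96.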

You know the shape of this from systems work. It is the per-request fixed cost that dominates once the variable cost gets small enough, and decode made the variable cost small. This chapter is two complementary attacks on the fixed cost. The first, CUDA graphs, removes the launch overhead within a step by recording the whole sequence of kernels once and replaying it with a single launch call. The second, async scheduling, removes the scheduling overhead between steps by overlapping the CPU’s planning of step N+1 with the GPU’s execution of step N. Both are, in vLLM, a careful exercise in identity and timing: a CUDA graph only replays correctly if the data sits at exactly the addresses it was recorded against, and async scheduling only works if the engine can advance to the next step before it has even learned what the last step produced.

A step is a static shape, most of the time

Here is the observation that makes CUDA graphs viable at all. A CUDA graph is a recording of a stream of GPU operations that can be replayed with a single launch call, amortizing all the per-kernel host overhead. Recording happens once: you run the model normally inside a capture context, and instead of executing, CUDA writes down the exact sequence of kernels, their launch parameters, and the device addresses they read and write. Afterwards, one replay() call re-issues that entire recorded sequence to the GPU without the host walking the model again. The catch is in the word exact: a graph records exact operations on exact memory addresses. You cannot replay a graph captured for a batch of 32 sequences against a batch of 47; the tensor shapes, and therefore the kernel launch dimensions, are baked into the recording.

So CUDA graphs are only viable if a step’s shape repeats often enough that one recording serves many steps. Decode is friendly to this. In a uniform decode step, every running sequence contributes exactly one query token, so the input is a tensor of shape [num_seqs, ...] and only two things vary step to step: num_seqs (how many sequences are decoding right now, which drifts as requests arrive and finish) and the KV-cache lengths (how far each sequence has progressed, which grows by one every step). vLLM neutralizes each of these so the recording stays valid. It handles the varying num_seqs by capturing a separate graph for each of a fixed set of batch sizes and padding any real batch up to the next captured size, so the graph always sees a shape it recorded. It handles the varying KV-cache lengths by keeping the variable-length attention computation outside the graph entirely. That second point is the crux, and it is why vLLM does not just wrap the whole model in one graph: attention is precisely the operator whose work depends on those ever-changing lengths, so it is the one part that cannot be frozen into a static recording. The next section shows how vLLM carves attention out.
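The padding half of that strategy is a sorted lookup. This minimal sketch assumes a hypothetical list of capture sizes; the real list is configurable. It rounds the live sequence count up to the next captured size and returns None when the batch is larger than anything recorded, which is the case that has to run without a graph.

```python
import bisect

# Hypothetical capture sizes; must be sorted ascending.
CAPTURE_SIZES = [1, 2, 4, 8, 16, 24, 32, 40, 48, 56, 64]


def pad_to_captured(num_seqs, capture_sizes=CAPTURE_SIZES):
    """Smallest captured batch size >= num_seqs, or None if none is large enough."""
    i = bisect.bisect_left(capture_sizes, num_seqs)
    if i == len(capture_sizes):
        return None  # bigger than every recording: fall back to eager
    return capture_sizes[i]
```

A batch of 47 replays the 48-wide graph with one padded slot. It trades a little wasted compute for skipping hundreds of individual launches.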

Splitting the graph at attention

When vLLM compiles a model with torch.compile under its custom backend, torch.compile first traces the model into an FX graph: a flat, ordered list of the operators the forward pass executes, in the same torch.fx representation that every torch.compile backend receives. vLLM does not hand that whole FX graph to a single CUDA-graph capture. It first splits the graph at the attention operators. The list of ops that constitute a split boundary lives in the compilation config:

# vllm/config/compilation.py
_attention_ops: ClassVar[list[str]] = [
    "vllm::unified_attention_with_output",
    "vllm::unified_mla_attention_with_output",
    "vllm::mamba_mixer2",
    "vllm::mamba_mixer",
    ...
    "vllm::deepseek_v4_attention",
]

Source: vllm/config/compilation.py

That is the same unified_attention_with_output custom op you met in Chapter 9 as the dispatch seam into the attention backend. Here it doubles as a fence: a marker that says “cut the graph here.” The backend’s VllmBackend.__call__ resolves splitting_ops (which defaults to this _attention_ops list) and calls split_graph, which walks the FX nodes in order and assigns each one a subgraph id, bumping the id every time it crosses a splitting op so that attention ends up alone in its own segment and the dense ops cluster into the segments between:

# vllm/compilation/backends.py
if should_split(node, splitting_ops):
    subgraph_id += 1
    node_to_subgraph_id[node] = subgraph_id
    split_op_graphs.append(subgraph_id)
    # keep consecutive splitting ops together
    if should_split(node.next, splitting_ops):
        subgraph_id -= 1
    else:
        subgraph_id += 1
else:
    node_to_subgraph_id[node] = subgraph_id

Source: vllm/compilation/backends.py
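The splitting loop is easy to replay on a plain list of op names. This minimal sketch mirrors the quoted logic, with list indices standing in for FX nodes and node.next. The op names are hypothetical labels, not real FX targets.

```python
def assign_subgraph_ids(ops, splitting_ops):
    """Mirror of the quoted loop: one id per op, bumping the id around splitting ops."""
    node_to_subgraph_id = []
    split_op_graphs = []
    subgraph_id = 0
    for i, op in enumerate(ops):
        if op in splitting_ops:
            subgraph_id += 1
            node_to_subgraph_id.append(subgraph_id)
            split_op_graphs.append(subgraph_id)
            next_op = ops[i + 1] if i + 1 < len(ops) else None
            # keep consecutive splitting ops together
            if next_op in splitting_ops:
                subgraph_id -= 1
            else:
                subgraph_id += 1
        else:
            node_to_subgraph_id.append(subgraph_id)
    return node_to_subgraph_id, split_op_graphs


ops = ["norm", "qkv", "attn", "o_proj", "mlp", "norm", "qkv", "attn", "o_proj", "mlp"]
ids, split_ids = assign_subgraph_ids(ops, splitting_ops={"attn"})
```

In this sketch the even ids (0, 2, 4) are the dense segments that get compiled and captured, and the odd ids are attention, left eager. Two back-to-back splitting ops share one odd id.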

The result is an alternating structure: a compilable subgraph of everything between attention calls (QKV projection, MLP, layernorms, residual adds, MoE routing), then the attention op running eagerly, then the next compilable subgraph, and so on. vLLM calls this piecewise compilation, and the matching CUDA-graph mode is PIECEWISE. The diagram below shows the layout for two transformer layers: the dense segments (green) are each wrapped in their own CUDA graph and replayed with one launch, while the attention ops (orange) run in eager mode in the gaps between them.

flowchart LR
    A["dense segment: layernorm, QKV proj (CUDA graph)"] --> B["attention (eager)"]
    B --> C["dense segment: out proj, residual, MLP, next layernorm, QKV proj (CUDA graph)"]
    C --> D["attention (eager)"]
    D --> E["dense segment: out proj, residual, MLP (CUDA graph)"]
    style A fill:#22863a,color:#fff
    style C fill:#22863a,color:#fff
    style E fill:#22863a,color:#fff
    style B fill:#d97706,color:#fff
    style D fill:#d97706,color:#fff

The reason for keeping attention out of the graph comes straight from Chapter 9: attention is the one operation whose work depends on the content of the batch (sequence lengths, block tables), and its backends often allocate scratch, run variable-iteration kernels, or branch on metadata in ways that do not capture cleanly. A recording freezes a fixed kernel shape, but attention’s per-step work is intrinsically variable, so freezing it would record the wrong amount of work. So vLLM replays the dense, shape-stable parts and leaves attention eager in the gaps. The launch overhead it eliminates is the overhead of the dense segments, which is where the vast majority of the per-step kernels live; the handful of eager attention launches per layer are cheap by comparison.

Each non-splitting subgraph becomes a PiecewiseBackend compiled for the relevant shapes, and wrap_with_cudagraph_if_needed wraps it in a CUDAGraphWrapper tagged PIECEWISE:

# vllm/compilation/backends.py
return static_graph_wrapper_class(
    runnable=piecewise_backend,
    vllm_config=vllm_config,
    runtime_mode=CUDAGraphMode.PIECEWISE,
    cudagraph_options=CUDAGraphOptions(
        debug_log_enable=is_first_graph,
        gc_disable=not is_first_graph,
        weak_ref_output=is_last_graph,
    ),
)

Source: vllm/compilation/backends.py

Note the per-graph options. Only the first subgraph logs (capturing happens for many shapes, and you do not want a log line each time), and garbage collection is disabled for all but the first graph because running gc.collect() between every layer’s capture would make capture pathologically slow.

Capture, replay, and the iron law of identity

The wrapper itself, in vllm/compilation/cuda_graph.py, is small and ruthless about memory. It is a tiny state machine per shape: the first call for a given shape records, and every subsequent call with that shape replays. The state diagram below traces that lifecycle.

stateDiagram-v2
    [*] --> NotCaptured
    NotCaptured --> Capturing: first call for this shape
    Capturing --> Captured: record kernels and input addresses
    Captured --> Captured: later call same shape, replay (one launch)
    note right of Capturing
        save data_ptr of every input tensor
    end note
    note right of Captured
        replay re-runs recorded kernels,
        which dereference the saved addresses
    end note

The capture path is the interesting one:

# vllm/compilation/cuda_graph.py
input_addresses = [
    x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
entry.input_addresses = input_addresses
cudagraph = torch.cuda.CUDAGraph()
...
with torch.cuda.graph(cudagraph, pool=self.graph_pool, stream=current_stream()):
    output = self.runnable(*args, **kwargs)

Read the capture path together with the replay behavior to see why addresses matter so much. At capture time the wrapper records data_ptr() — the raw device address — of every input tensor, then runs the model once inside torch.cuda.graph(...) so CUDA writes down the kernel sequence. A captured graph reads from and writes to the exact device addresses that were live at that moment. Crucially, replaying it does not re-read whatever Python objects you pass as arguments next time; it re-runs the recorded kernels, and those kernels dereference the raw pointers they were recorded against. If the data you want this step to process sits at a different address than last step, replay will silently process whatever happens to be at the old address instead. This is the single most important fact about CUDA graphs and the source of nearly every bug people hit with them: the inputs must live at the same addresses on every replay. The wrapper does not solve this itself, on purpose. Its docstring is explicit that it stores no persistent buffers and copies nothing; it assumes the caller feeds it tensors that already live in stable buffers. What it does do, under debug logging, is verify the assumption:

# vllm/compilation/cuda_graph.py
new_input_addresses = [
    x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
assert new_input_addresses == entry.input_addresses, (
    f"Input addresses for cudagraphs are different "
    f"during replay. Expected {entry.input_addresses}, "
    f"got {new_input_addresses}"
)

Source: vllm/compilation/cuda_graph.py

If that assert ever fires, someone allocated a fresh tensor where a persistent buffer was expected, and the graph would have silently read garbage.

The stable buffers are supplied by the model runner. In gpu_model_runner.py the runner allocates persistent input tensors once, sized to the maximum, and copies each step’s data into them rather than allocating anew:

# vllm/v1/worker/gpu_model_runner.py
# Persistent buffers for CUDA graphs.
self.input_ids = self._make_buffer(self.max_num_tokens, dtype=torch.int32)
self.positions = torch.zeros(
    self.max_num_tokens, dtype=torch.int64, device=self.device
)

This is the resolution of the address problem. Because self.input_ids is allocated exactly once and reused, its device address never changes. Every step writes the real token ids into the first slots of that one buffer (overwriting last step’s contents in place) and lets the captured graph read from that fixed address. The two invariants a graph needs are now both satisfied by construction: the dispatcher’s padding (next section) guarantees the shape the graph sees is one it captured, and the persistent buffer guarantees the address is the one it recorded. Shape from padding, address from the persistent buffer; that pair is the whole contract.
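The address contract can be demonstrated without a GPU. In this toy sketch, hypothetical ToyDevice and ToyGraph classes stand in for device memory and a captured graph. The graph bakes in an input address at capture time, so a freshly allocated input is silently ignored, while data copied into the persistent buffer in place is seen on replay.

```python
class ToyDevice:
    """Stand-in for GPU memory: integer addresses map to mutable buffers."""

    def __init__(self):
        self.memory = {}
        self._next_addr = 0

    def alloc(self, values):
        addr = self._next_addr
        self._next_addr += 1
        self.memory[addr] = list(values)
        return addr


class ToyGraph:
    """A 'captured graph' that remembers the address it read at capture time."""

    def __init__(self, device, kernel, input_addr):
        self.device = device
        self.kernel = kernel
        self.input_addr = input_addr  # baked into the recording

    def replay(self):
        # Replay dereferences the recorded address, never a new argument.
        return self.kernel(self.device.memory[self.input_addr])


dev = ToyDevice()
persistent = dev.alloc([0] * 4)         # allocated once, sized to the max
graph = ToyGraph(dev, sum, persistent)  # "capture" against the persistent buffer

# Bug: this step's tokens go into a fresh allocation the graph never reads.
fresh = dev.alloc([5, 6, 7])
stale = graph.replay()

# Fix: overwrite the first slots of the persistent buffer in place.
buf = dev.memory[persistent]
buf[:3] = [5, 6, 7]
buf[3:] = [0] * (len(buf) - 3)          # the padded tail
good = graph.replay()
```

The stale replay returns the old contents (0) without any error, which is exactly the silent-garbage failure the debug assert exists to catch.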

The output side is equally careful about lifetime. The wrapper always stores the cached output as a weak reference (weak_ref_tensors) so PyTorch’s graph pool, not Python, owns that memory. The weak_ref_output option then handles a subtlety of the piecewise chain: an intermediate subgraph’s output is fed straight into the next subgraph’s capture and must stay alive, but the last subgraph’s output “will not be used by any other cuda graph,” so it is the only one safe to release immediately inside the capture context (weak_ref_output=is_last_graph). Getting this wrong leaks the entire graph pool. The file’s inline comment, “mind-exploding: carefully manage the reference and memory,” is not hyperbole.

Dispatch: which graph, padded to what

At runtime the runner does not decide capture-versus-replay itself; it asks the CudagraphDispatcher to map the current batch onto a captured graph. Think of the dispatcher as a lookup table from “what does this batch look like” to “which recording, if any, fits it.” The key is a BatchDescriptor — the dataclass that captures everything a graph is specialized on (padded token count, request count, whether it is a uniform decode, LoRA state). The dispatcher holds two sets of these keys, one per runtime mode (FULL and PIECEWISE), because the two modes specialize on different things. Dispatch pads the batch up to the nearest captured size and tries the keys in priority order:

# vllm/v1/cudagraph_dispatcher.py
if CUDAGraphMode.FULL in allowed_modes:
    batch_desc_to_check = batch_desc
    if batch_desc_to_check in self.cudagraph_keys[CUDAGraphMode.FULL]:
        return CUDAGraphMode.FULL, batch_desc_to_check

if CUDAGraphMode.PIECEWISE in allowed_modes:
    batch_desc_to_check = replace(batch_desc, num_reqs=None, uniform=False)
    if batch_desc_to_check in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
        return CUDAGraphMode.PIECEWISE, batch_desc_to_check

return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)

Source: vllm/v1/cudagraph_dispatcher.py

The diagram below traces this three-way decision, and three points follow from it.

flowchart TD
    Start["incoming batch, padded to nearest captured size"] --> F{"FULL key matches?"}
    F -->|yes| UseF["replay FULL graph: whole step including attention, one launch"]
    F -->|no| P{"PIECEWISE key matches? (relax num_reqs and uniform)"}
    P -->|yes| UseP["replay PIECEWISE graphs: dense parts captured, attention eager"]
    P -->|no| None["mode NONE: run fully eager, launch every kernel by hand"]

First, FULL is preferred when available. A FULL graph captures the entire step including attention, so it eliminates even the eager attention launches that PIECEWISE leaves in the gaps. This is the FULL_AND_PIECEWISE default: decode batches get a full graph (Chapter 9’s flash-attention backends support full-graph capture for uniform decode, where the attention shape is regular enough to freeze), while prefill or mixed batches, whose attention is too irregular, fall back to piecewise.

Second, the piecewise lookup relaxes the key before checking it, nulling out num_reqs and uniform. A piecewise graph, with attention living outside the recording, genuinely does not care how the tokens are partitioned into requests; only the token count matters to the dense segments. The full graph cannot relax those fields, because its in-graph attention kernel reads metadata (block tables, sequence boundaries) that depends on exactly how many requests there are and whether they are a uniform decode. The relaxation is why a single piecewise recording covers many request-count variations that would each need their own full graph.

Third, if neither mode has a matching key (the batch exceeds the largest captured size, or a feature like cascade attention is incompatible), dispatch returns NONE and the step runs fully eager, paying the per-kernel launch overhead this whole chapter is trying to kill. That last fallback is exactly the “cudagraph fallback” regression the observability chapter teaches you to spot: a few oversized batches quietly costing you the launch overhead you thought you had eliminated.
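The priority order can be sketched in a few lines of plain Python. This is a hypothetical simplification, not vLLM’s class: the BatchDescriptor here carries only num_tokens, num_reqs, and uniform (the real one also carries LoRA state), the batch is assumed to be already padded, and the allowed_modes filter is omitted.

```python
from dataclasses import dataclass, replace
from enum import Enum
from typing import Optional


class Mode(Enum):
    NONE = "none"
    PIECEWISE = "piecewise"
    FULL = "full"


@dataclass(frozen=True)
class BatchDescriptor:
    # Hypothetical simplified key; frozen so it is hashable and can live in a set.
    num_tokens: int
    num_reqs: Optional[int] = None
    uniform: bool = False


def dispatch(desc, full_keys, piecewise_keys):
    """Return (mode, key) for an already-padded batch descriptor."""
    # FULL keys must match exactly: in-graph attention depends on num_reqs/uniform.
    if desc in full_keys:
        return Mode.FULL, desc
    # PIECEWISE graphs only care about the token count, so relax the key first.
    relaxed = replace(desc, num_reqs=None, uniform=False)
    if relaxed in piecewise_keys:
        return Mode.PIECEWISE, relaxed
    # No recording fits this shape: run the step eagerly.
    return Mode.NONE, BatchDescriptor(desc.num_tokens)


FULL_KEYS = {BatchDescriptor(16, num_reqs=16, uniform=True)}  # uniform decode, 16 requests
PIECEWISE_KEYS = {BatchDescriptor(16), BatchDescriptor(32)}

decode = dispatch(BatchDescriptor(16, 16, True), FULL_KEYS, PIECEWISE_KEYS)
mixed = dispatch(BatchDescriptor(16, 3, False), FULL_KEYS, PIECEWISE_KEYS)
oversized = dispatch(BatchDescriptor(64, 4, False), FULL_KEYS, PIECEWISE_KEYS)
# decode -> FULL; mixed -> PIECEWISE under key BatchDescriptor(16); oversized -> NONE
```

The mixed batch of 3 requests lands on the same piecewise recording a 16-request batch would use, which is the relaxation at work.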

The padding is not free. Padding a batch of 47 up to a captured size of 48 means running one sequence’s worth of phantom work, but a batch of 33 pads all the way to 48, wasting nearly a third of the step (15 of 48 slots). The waste is a sawtooth in the real batch size: it spikes just above each capture size (the worst case in the set below is a batch of 9, which pads to 16 and wastes 7 of 16 slots, nearly half the step) and falls to zero right at a captured size. The curve below plots that wasted fraction against the real batch size for a representative set of capture sizes; the teeth get shallower toward the large captured sizes because the gaps between them shrink in relative terms.

Illustrative: the sawtooth shape and the spike-above-each-capture-size pattern are exact for the listed capture sizes; vLLM’s actual capture-size set and cap are tunable, so the real teeth differ.

vLLM bounds this with a sensible set of capture sizes and a capped maximum, and the dispatcher precomputes a dense bs_to_padded_graph_size table so the padding decision is a constant-time lookup. The waste is real but small relative to the launch overhead it eliminates, and the capture sizes are a tunable knob.
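A minimal sketch of the idea behind a dense padding table, in plain Python. The capture sizes are hypothetical, chosen so the 9-to-16 and 33-to-48 cases from the text come out; vLLM’s real defaults differ and are configurable, and the helper names are illustrative rather than vLLM’s.

```python
def build_padding_table(capture_sizes):
    """Map every batch size up to the largest capture size to its padded size.

    One pass at startup; afterwards each step pays a single list index.
    """
    sizes = sorted(set(capture_sizes))
    table = [0] * (sizes[-1] + 1)  # index 0 is unused
    idx = 0
    for bs in range(1, sizes[-1] + 1):
        while sizes[idx] < bs:  # advance to the smallest capture size >= bs
            idx += 1
        table[bs] = sizes[idx]
    return table


def padded_size(bs, table):
    """Padded size for a batch, or None if it exceeds every capture size (eager)."""
    return table[bs] if 0 < bs < len(table) else None


def wasted_fraction(bs, table):
    """Share of the padded step spent on phantom slots."""
    padded = padded_size(bs, table)
    return None if padded is None else (padded - bs) / padded


# Hypothetical capture sizes, not vLLM's defaults.
CAPTURE_SIZES = [1, 2, 4, 8, 16, 24, 32, 48]
TABLE = build_padding_table(CAPTURE_SIZES)

worst = max(range(1, 49), key=lambda bs: wasted_fraction(bs, TABLE))
# padded_size(33, TABLE) -> 48, padded_size(49, TABLE) -> None, worst -> 9
```

Anything above the largest capture size has no table entry at all, which is the eager fallback from the dispatcher discussion.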

The other idle gap: scheduling

CUDA graphs kill the launch overhead inside a step. They do nothing for the overhead between steps. Even with a perfectly graphed step, there is a seam: after the GPU finishes step N, the host has to copy the sampled token ids back, run the scheduler’s update_from_output, decide the next batch, and build the input tensors, and only then can it launch step N+1. The synchronous loop alternates strictly between the two, so the GPU sits idle during every one of those host phases. The top half of the diagram below shows that idle gap; the bottom half shows the fix.

sequenceDiagram
    participant CPU
    participant GPU
    Note over CPU,GPU: synchronous: GPU idles while CPU schedules
    CPU->>GPU: launch step N
    GPU-->>CPU: step N done
    Note over GPU: idle
    CPU->>CPU: schedule step N+1
    CPU->>GPU: launch step N+1
    Note over CPU,GPU: async: schedule N+1 while GPU runs N
    CPU->>GPU: launch step N
    CPU->>CPU: schedule step N+1 (overlaps GPU)
    CPU->>GPU: launch step N+1
    GPU-->>CPU: step N done

The fix is the classic one: pipeline it. Run the CPU scheduling of step N+1 concurrently with the GPU execution of step N, so that the moment the GPU finishes step N the next batch is already queued and waiting. vLLM calls this async scheduling, and it is gated by a single scheduler-config flag:

# vllm/config/scheduler.py
async_scheduling: bool | None = None
"""If set to False, disable async scheduling. Async scheduling helps to
avoid gaps in GPU utilization, leading to better latency and throughput.
"""

Source: vllm/config/scheduler.py

When it is on, the engine runs step_with_batch_queue instead of the plain step. The structural change is that execute_model returns a future instead of blocking for the GPU result, and that future goes into a queue; the engine then prefers to schedule and launch the next batch rather than immediately wait on the previous one. It only blocks on a future when the queue is full or there is nothing left to schedule, which is what keeps the GPU continuously fed:

# vllm/v1/engine/core.py
if not deferred_scheduler_output:
    batch_queue.appendleft((future, scheduler_output, exec_future))
    if len(batch_queue) < self.batch_queue_size and (
        model_executed or self.scheduler.has_requests()
    ):
        # Don't block on next worker response unless the queue is full
        # or there are no more requests to schedule.
        return None, model_executed

Source: vllm/v1/engine/core.py
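The control flow of that loop can be sketched with a thread pool standing in for the GPU worker and a plain deque standing in for the scheduler’s backlog. The function name and parameters are hypothetical; the real step_with_batch_queue also handles deferred outputs and the engine’s own bookkeeping.

```python
from collections import deque
from concurrent.futures import ThreadPoolExecutor


def run_with_batch_queue(batches, execute, queue_size=2):
    """Overlap scheduling with execution; block only when forced to.

    batches:    work a (hypothetical) scheduler would hand out, in order.
    execute:    stand-in for model execution; runs on a worker thread.
    queue_size: max batches in flight (vLLM sizes its queue with
                max_concurrent_batches).
    Returns outputs in submission order and the peak number in flight.
    """
    backlog = deque(batches)
    in_flight = deque()  # newest on the left, oldest on the right
    outputs, peak = [], 0
    # A single worker thread mimics one GPU stream executing batches in order.
    with ThreadPoolExecutor(max_workers=1) as worker:
        while backlog or in_flight:
            if backlog and len(in_flight) < queue_size:
                # Room in the queue and work to do: launch rather than wait.
                in_flight.appendleft(worker.submit(execute, backlog.popleft()))
                peak = max(peak, len(in_flight))
                continue
            # Queue full, or nothing left to schedule: wait on the oldest batch.
            outputs.append(in_flight.pop().result())
    return outputs, peak


outs, peak = run_with_batch_queue([1, 2, 3, 4], lambda b: b * 10)
# outs == [10, 20, 30, 40]; peak == 2
```

With queue_size=1 the same loop degenerates into the synchronous alternation: every launch is immediately followed by a blocking wait.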

The queue size is max_concurrent_batches, which is where async scheduling and pipeline parallelism meet. The same machinery that overlaps scheduling with execution is what keeps a pipeline-parallel deployment’s stages full:

# vllm/config/vllm.py
def max_concurrent_batches(self) -> int:
    # PP requires PP-size concurrent batches to fill the pipeline.
    # Async scheduling requires 2 concurrent batches to overlap.
    pp_size = self.parallel_config.pipeline_parallel_size
    if self.scheduler_config.async_scheduling:
        if self.use_v2_model_runner:
            return pp_size + 1
        if pp_size <= 1:
            return 2
    return pp_size

Source: vllm/config/vllm.py

Async scheduling on a single replica wants two batches in flight; PP wants pp_size; the v2 runner combining both wants pp_size + 1. It is the same idea you reach for in any producer/consumer pipeline: queue just enough work to hide the other side’s latency without unbounded buffering.
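A back-of-envelope model makes the payoff concrete. It assumes each step costs a fixed amount of host time (schedule, build inputs) and a fixed amount of GPU time; the microsecond values below are illustrative, not measurements. The synchronous loop pays both costs every step. Once the pipeline is primed, the overlapped loop pays only the slower of the two.

```python
def synchronous_us(n_steps, cpu_us, gpu_us):
    """Strict alternation: the GPU idles during every host phase."""
    return n_steps * (cpu_us + gpu_us)


def overlapped_us(n_steps, cpu_us, gpu_us):
    """Host schedules step k+1 while the GPU runs step k.

    Only the first schedule and the last GPU step are exposed; every
    step in between costs whichever side is slower.
    """
    if n_steps == 0:
        return 0
    return cpu_us + gpu_us + (n_steps - 1) * max(cpu_us, gpu_us)


# Illustrative numbers only: 300 us of host work and 1000 us of GPU work per step.
sync = synchronous_us(100, 300, 1000)     # 130000 us
overlap = overlapped_us(100, 300, 1000)   # 100300 us
```

When host work exceeds GPU work, the overlapped loop becomes host-bound at the host cost per step, so keeping per-step host work small still matters with async scheduling on.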

The catch: scheduling the next step before you know the last token

There is a genuine problem buried in that overlap. To schedule step N+1, the scheduler needs each sequence’s next input token — which is the one sampled by step N, still running on the GPU. Waiting to read it back would kill the overlap. So async scheduling schedules step N+1 without yet knowing the sampled tokens, using placeholders.

The AsyncScheduler advances each running request’s bookkeeping optimistically, recording how many tokens are “in flight” as placeholders rather than concrete ids:

# vllm/v1/core/sched/async_scheduler.py
cur_num_spec_tokens = len(spec_decode_tokens.get(req_id, ()))
request.num_output_placeholders += 1 + cur_num_spec_tokens
# Add placeholders for the new draft/spec tokens.
# We will update the actual spec token ids in the worker process.
request.spec_token_ids = self._spec_token_placeholders

Source: vllm/v1/core/sched/async_scheduler.py

The genuinely clever part is in the worker, and it is what lets the placeholders work without ever stalling. Recall the chicken-and-egg problem: step N+1’s input is the token sampled by step N, but that token is still being computed on the GPU. The naive fix would be to copy the token back to the host and feed it into the next step’s input tensor, but that device-to-host copy is exactly the synchronization that would destroy the overlap. The insight is that the token never needs to leave the GPU to become the next step’s input, because both the sampled-token tensor and the persistent input buffer already live on the device. The model runner caches the previous step’s sampled-token tensor (prev_sampled_token_ids) and, when preparing the next batch’s input_ids, scatters those device-side tokens straight into the persistent input buffer with a GPU kernel:

# vllm/v1/worker/gpu_model_runner.py
self.input_ids.gpu.scatter_(
    dim=0,
    index=sampled_tokens_index_tensor,
    src=self.input_batch.prev_sampled_token_ids[
        prev_common_req_indices_tensor, 0
    ],
)

Source: vllm/v1/worker/gpu_model_runner.py

The host never learns the token id before launching the next step; the handoff from step N to step N+1 happens entirely on-device via a GPU scatter, never touching the CPU on the critical path. The sequence diagram below contrasts the two handoff paths: the on-device path (solid, on the critical path) carries the token from step N straight into step N+1, while the host-bound copy (the background detokenize/stream path) runs off to the side.

sequenceDiagram
    participant CPU
    participant GPU
    CPU->>GPU: launch step N
    GPU->>GPU: sample token, store in prev_sampled_token_ids (on device)
    GPU->>GPU: scatter prev tokens into persistent input_ids buffer
    CPU->>GPU: launch step N+1 (reads tokens already on device)
    GPU-->>CPU: background copy of tokens (separate stream, for detokenize and stop checks)

There is even a common-case fast path: if the decode batch is unchanged and nothing was reordered, the whole thing collapses to one contiguous slice copy:

# vllm/v1/worker/gpu_model_runner.py
if common_indices_match and max_flattened_index == (num_common_tokens - 1):
    self.input_ids.gpu[:num_common_tokens].copy_(
        self.input_batch.prev_sampled_token_ids[:num_common_tokens, 0],
        non_blocking=True,
    )
    return

Source: vllm/v1/worker/gpu_model_runner.py
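The index arithmetic that decides between the scatter and the fast-path slice copy can be sketched on plain Python lists. This is a hypothetical host-side illustration of the layout logic only: in vLLM the result drives a device-side scatter_ or copy_, so nothing leaves the GPU. It assumes each request’s real input token sits just before any draft tokens in its slice of the flat batch.

```python
def place_prev_tokens(input_ids, cu_num_tokens, draft_lens, prev_rows, prev_sampled):
    """Write last step's sampled tokens into this step's flat input buffer.

    input_ids:     persistent flat buffer, mutated in place (never reallocated).
    cu_num_tokens: cumulative scheduled tokens through each request, e.g. [1, 2, 5].
    draft_lens:    draft tokens scheduled after each request's real input token.
    prev_rows:     each request's row in the previous step's output, or None
                   for a request that was not in the previous batch.
    prev_sampled:  token sampled for each previous-step row.
    Returns which path was taken.
    """
    targets, sources = [], []
    for i, prev_row in enumerate(prev_rows):
        if prev_row is None:
            continue  # new request: its tokens were written from the host as usual
        # The real (non-draft) input token sits just before the draft tokens.
        targets.append(cu_num_tokens[i] - 1 - draft_lens[i])
        sources.append(prev_row)
    n = len(targets)
    if targets == sources == list(range(n)):
        # Unchanged, unreordered decode batch: one contiguous slice copy.
        input_ids[:n] = prev_sampled[:n]
        return "slice"
    for t, s in zip(targets, sources):  # general case: an element-wise scatter
        input_ids[t] = prev_sampled[s]
    return "scatter"


buf = [0] * 6
path = place_prev_tokens(buf, [1, 2, 3], [0, 0, 0], [0, 1, 2], [11, 22, 33])
# path == "slice"; buf[:3] == [11, 22, 33]
```

Reordering two requests, or adding a new prefill to the batch, breaks the identity layout and forces the general scatter.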

The device-to-host copy of the sampled tokens still happens — the host needs them eventually to detokenize, stream (Chapter 11), and check stop conditions — but it is moved off the critical path. Instead of blocking the scheduling loop, it runs on a separate CUDA stream, so the host can keep launching steps while the copy proceeds in the background, and the result is only awaited when something actually needs the token ids on the CPU. AsyncGPUModelRunnerOutput kicks it off non-blocking and only synchronizes when the output is actually consumed:

# vllm/v1/worker/gpu_model_runner.py
with torch.cuda.stream(async_output_copy_stream):
    async_output_copy_stream.wait_stream(default_stream)
    self.sampled_token_ids_cpu = self._sampled_token_ids.to(
        "cpu", non_blocking=True
    )
    ...
    self.async_copy_ready_event.record()

Source: vllm/v1/worker/gpu_model_runner.py

Where the optimism leaks

The placeholders are an optimistic bet: the scheduler assumes the steps it launched speculatively will all turn out to be wanted. Optimism has costs, paid in the corner cases where that bet is wrong and the engine has already committed work it now needs to unwind. The cleanest example is a running request forced out from under its in-flight work. If the prefix cache is reset and every running request is force-preempted, any output frames the engine already launched for those requests are now stale (they computed tokens for a request that no longer exists in its old form) and must be dropped when they return from the GPU. The scheduler records exactly how many such frames are outstanding, namely the request’s placeholder count at preemption time, and the async scheduler drains them one per call:

# vllm/v1/core/sched/async_scheduler.py
if request.async_tokens_to_discard > 0:
    # The request was force-preempted in reset_prefix_cache; drop one
    # stale in-flight async output frame per call until the counter
    # is drained.
    request.async_tokens_to_discard -= 1
    return [], False

Source: vllm/v1/core/sched/async_scheduler.py

The counter is set in reset_prefix_cache itself, where request.async_tokens_to_discard = request.num_output_placeholders captures the in-flight count before zeroing the placeholders. The everyday case — a sampled token that turns out to be a stop token, after step N+1 was already scheduled for that sequence — is handled more quietly: the over-produced tokens are simply truncated when the stopped request’s output is reconciled, the same num_output_placeholders bookkeeping unwinding by the number of tokens that actually came back.
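The placeholder and discard bookkeeping can be modeled as a small state machine. The field names follow the excerpts above, but the class and functions are hypothetical simplifications: the real scheduler also handles stop tokens, spec-decode rejection, and re-prefill after preemption.

```python
from dataclasses import dataclass, field


@dataclass
class Req:
    # Hypothetical stand-in for the scheduler's per-request state.
    num_output_placeholders: int = 0
    async_tokens_to_discard: int = 0
    output: list = field(default_factory=list)


def on_schedule(req, num_spec_tokens=0):
    # One sampled token plus any draft tokens are now "in flight".
    req.num_output_placeholders += 1 + num_spec_tokens


def on_reset_prefix_cache(req):
    # Force-preemption: every in-flight frame for this request is now stale.
    req.async_tokens_to_discard = req.num_output_placeholders
    req.num_output_placeholders = 0


def on_output(req, new_token_ids):
    if req.async_tokens_to_discard > 0:
        req.async_tokens_to_discard -= 1  # drop one stale frame per call
        return []
    req.num_output_placeholders -= len(new_token_ids)  # placeholders resolved
    req.output.extend(new_token_ids)
    return new_token_ids


r = Req()
on_schedule(r)
on_schedule(r)            # two steps launched before their tokens return
on_reset_prefix_cache(r)  # both in-flight frames become stale
dropped = [on_output(r, [5]), on_output(r, [6])]  # [[], []]
on_schedule(r)            # rescheduled after preemption
kept = on_output(r, [7])  # [7]
```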

Speculative decoding (Chapter 13) makes this harder, because the number of tokens accepted per step is not known until verification runs on the GPU. That is why the placeholders above include 1 + cur_num_spec_tokens and why the engine’s post_step skips its usual draft-token bookkeeping under async scheduling — the draft ids are resolved in the worker, not the engine:

# vllm/v1/engine/core.py
if not self.async_scheduling and self.use_spec_decode and model_executed:
    draft_token_ids = self.model_executor.take_draft_token_ids()
    if draft_token_ids is not None:
        self.scheduler.update_draft_token_ids(draft_token_ids)

Source: vllm/v1/engine/core.py

And structured output (Chapter 14) is the hardest interaction of all, because the grammar bitmask for step N+1 depends on which token was actually sampled at step N. When a request is constrained by a grammar, the engine cannot blindly sample the next step; it must defer sampling until the prior output is processed, which is the deferred_scheduler_output branch in step_with_batch_queue above. Async scheduling and constrained decoding pull in opposite directions, and the code threads the needle by giving up the overlap precisely for the requests that need the prior token to compute their mask.

Honest limits

These techniques are mature and on by default, but not magic. CUDA graphs cost capture time at startup (one capture per shape), cost memory (the graph pool plus padding waste), and fall back to eager whenever a batch does not match a captured descriptor — the larger your max_num_seqs relative to your capture sizes, the more often that happens. Full-graph capture depends on the attention backend supporting it, so a Chapter 9 backend choice quietly decides whether you get FULL or only PIECEWISE. Async scheduling buys throughput at the cost of a more intricate state machine whose edge cases (stops, preemption, spec-decode rejection, grammar masks) are where subtle correctness bugs live, and it does not compose freely: the v1 runner does not fully support async scheduling with PP, which is why max_concurrent_batches special-cases it.

Both techniques share a theme that runs through the engine: the win comes from refusing to let the GPU wait on the host. We removed the launch wait by recording it and the scheduling wait by overlapping it. The one host round-trip we have not removed is the one that turns sampled token ids into streamed text — the device-to-host copy AsyncGPUModelRunnerOutput defers but cannot eliminate. That copy, the sampler pipeline that produces it, and the surprisingly racy path from token ids back to an SSE stream are the next chapter.

Further reading

This chapter’s mechanisms are engineering rather than research, so there is no single paper to cite; the techniques are best learned from the vLLM source quoted above and from the upstream documentation for torch.compile and torch.cuda.CUDAGraph. For background on the kernel-launch-overhead problem that motivates CUDA graphs, the roofline framing from Chapter 3 (Data Movement Is All You Need, arXiv:2007.00072) and Chapter 9’s attention-kernel discussion are the relevant prior reading within this book.

The sampler and the egress path: logits to streamed text

Everything in Part II so far has been about getting the forward pass to happen efficiently: paging the KV cache, budgeting the batch, slicing prefills, reusing prefixes, dispatching the right attention kernel, and finally (in Chapter 10) replaying the whole thing as a CUDA graph while the next step’s scheduling overlaps on the CPU. But a forward pass produces a tensor of hidden states, and a hidden-state tensor is not an HTTP response. This chapter follows the last stretch of the pipeline: from the model’s output projection, through the sampler, across the process boundary back to the API server, and out as Server-Sent Events.

It is a deceptively boring-sounding stretch, which is exactly why it’s worth a chapter. Two things hide here that people get wrong. First, the sampler is a small GPU pipeline with exactly one mandatory GPU-to-CPU synchronization, and async scheduling from Chapter 10 turns even that into something you can skip on the hot path. Second, the path from sampled token ids to a correct SSE stream contains a distributed-cancellation race between two processes, and getting it wrong means either leaked GPU work or truncated outputs.

Before diving into either, it helps to have the whole route in view. The diagram below traces a single token from the moment the forward pass finishes to the moment its text leaves the server. Notice that the journey crosses a process boundary in the middle: the left half lives in the engine process (on and near the GPU), the right half lives in the frontend process (the HTTP server). That boundary is where the second problem hides.

flowchart LR
    subgraph engine["engine process (on and near the GPU)"]
        A["hidden states (GPU)"] --> B["gather last position per request"]
        B --> C["output projection to vocab logits"]
        C --> D["sampler: penalties, temperature, top-k/top-p, draw"]
        D --> E["sampled token ids (still on GPU)"]
        E --> F["the one D2H sync: copy ids to CPU"]
    end
    F -->|"ZMQ socket"| G
    subgraph frontend["frontend process (HTTP server)"]
        G["receive EngineCoreOutput"] --> H["incremental detokenizer: ids to text"]
        H --> I["stop-string check and abort handshake"]
        I --> J["SSE frame: data ... then DONE"]
    end

Only the last position matters

Recall from Chapter 1 that decode appends one token at a time, and from Chapter 5 that a continuously batched step is a flat bag of tokens with no prefill/decode distinction. A subtle consequence: most of the positions in that flat bag produce hidden states you will immediately throw away. During prefill, the model computes a hidden state for every prompt token, but you only ever sample from the last one. During a chunked prefill (Chapter 6) you don’t even want that: the partial request’s last position isn’t the real end of the prompt yet.

So before the engine pays for the (large) output projection from hidden dimension to vocabulary size, it gathers only the rows it actually needs. The “output projection” is the final linear layer that maps each hidden state (a vector of size hidden_dim, a few thousand) to a logit vector of size vocab_size (often 128k or more). One logit per vocabulary entry, per position. Running that projection on every position in the flat batch would produce a giant [num_tokens, vocab] tensor, almost all of which is discarded. So the runner picks the keepers first. In the model runner:

# vllm/v1/worker/gpu_model_runner.py
logits_indices = query_start_loc[1:] - 1

Source: vllm/v1/worker/gpu_model_runner.py

To unpack that one line: query_start_loc is the cumulative-sum boundary array for the batch (the same ragged-sequence bookkeeping the attention kernel uses in Chapter 9). If three requests contribute 5, 3, and 8 tokens to the flat batch, query_start_loc is [0, 5, 8, 16]: each entry is where a request starts. Dropping the first entry with [1:] gives [5, 8, 16], the position just past the end of each request, and subtracting 1 gives [4, 7, 15], the last token of each request. So logits_indices is, in plain terms, “the last position of each request.” Those indices select hidden states, and only then is the vocabulary projection applied:

# vllm/v1/worker/gpu_model_runner.py
sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states)

Source: vllm/v1/worker/gpu_model_runner.py
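The same arithmetic in plain Python, using the 5/3/8 example from the text. Lists stand in for GPU tensors, the helper names are illustrative, and the vocabulary size is a round hypothetical number.

```python
from itertools import accumulate


def last_token_indices(tokens_per_request):
    """Flat-batch index of each request's last token (query_start_loc[1:] - 1)."""
    query_start_loc = [0, *accumulate(tokens_per_request)]  # e.g. [0, 5, 8, 16]
    return [end - 1 for end in query_start_loc[1:]]


def gather_rows(hidden_states, indices):
    """Keep only the rows that will be projected to vocabulary logits."""
    return [hidden_states[i] for i in indices]


counts = [5, 3, 8]
idx = last_token_indices(counts)                # [4, 7, 15]
hidden = [f"h{i}" for i in range(sum(counts))]  # stand-in hidden states
kept = gather_rows(hidden, idx)                 # ['h4', 'h7', 'h15']

# Logits materialized with and without the gather, for a hypothetical 128k vocabulary.
VOCAB = 128_000
logits_without_gather = sum(counts) * VOCAB  # 2_048_000
logits_with_gather = len(idx) * VOCAB        # 384_000
```

Even in this tiny batch the gather shrinks the logits tensor by more than five times; in a prefill-heavy batch of thousands of prompt tokens the ratio is far larger.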

This is why Chapter 3 could say prefill “produces no logits worth keeping.” The compute-bound prefill phase exists to fill the KV cache; the only logit it yields is the one that seeds decode. Gathering first also keeps the expensive [num_sampled, vocab] logits tensor as small as the batch allows.

The sampler as an ordered pipeline

The sampler turns a logit vector into a single chosen token id. A logit is an unnormalized score for each vocabulary entry; a softmax over the logits would turn them into a probability distribution, and “sampling” means drawing one token from that distribution. Between the raw logits and the draw sit a stack of transforms that reshape the distribution: penalties that discourage repetition, masks that forbid certain tokens, temperature that flattens or sharpens the curve, and top-k/top-p that lop off the unlikely tail. The sampler is an nn.Module, and its docstring is unusually candid about being a fixed sequence of those stages. The class comment in vllm/v1/sample/sampler.py lays out the order explicitly: compute logprobs if requested, cast to float32, apply allowed-token and bad-word masks, apply the non-argmax-invariant logit processors and penalties, then sample, which itself temperature-scales, applies argmax-invariant processors (min-p), applies top-k/top-p, and draws.

That order is the crux, so it is worth naming the organizing principle: every transform is classified by whether it can move the argmax (the single highest-scoring token). A transform is “argmax-invariant” if it can never change which token is on top, and “non-argmax-invariant” if it can. This matters because a greedy request (temperature zero) simply takes the argmax, so any argmax-invariant transform is wasted work for it. The pipeline below puts all the argmax-changing transforms first, outside sample(), and tucks the argmax-preserving ones inside sample() where a greedy request can skip them entirely. The diagram traces both paths through the sampler.

flowchart TD
    L["logits (gathered, one row per request)"] --> F32["cast to float32"]
    F32 --> M["apply allowed-token and bad-word masks"]
    M --> P["apply penalties and logit bias (can move the argmax)"]
    P --> Q{"all greedy?"}
    Q -->|"yes"| G["greedy_sample: just take the argmax"]
    Q -->|"no"| S["sample(): temperature, min-p, top-k/top-p, draw"]
    S --> W{"row is greedy?"}
    W -->|"yes (mixed batch)"| GR["use that row's argmax"]
    W -->|"no"| RD["use that row's random draw"]
    G --> OUT["sampled token ids (GPU tensors)"]
    GR --> OUT
    RD --> OUT
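A small pure-Python check makes the argmax-invariant distinction concrete. The repetition-penalty formula shown (shrink positive logits, amplify negative ones, for already-generated tokens) is a common formulation included for illustration; treat it as an assumption, not a transcript of vLLM’s penalty code.

```python
def argmax(row):
    return max(range(len(row)), key=row.__getitem__)


def scale_by_temperature(row, t):
    # Dividing by a positive constant preserves the ordering: argmax-invariant.
    return [x / t for x in row]


def repetition_penalty(row, seen, penalty):
    # Penalize only tokens already generated; this can demote the current leader.
    return [
        (x / penalty if x > 0 else x * penalty) if i in seen else x
        for i, x in enumerate(row)
    ]


logits = [2.0, 1.8, -0.5]
cooled = scale_by_temperature(logits, 0.7)  # argmax stays at token 0
penalized = repetition_penalty(logits, seen={0}, penalty=1.3)
# 2.0 / 1.3 is about 1.54 < 1.8, so token 1 is now on top.
```

The same reasoning covers min-p and top-k/top-p: any filter that only removes tokens while always keeping the top one cannot change a greedy pick.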

Order is not arbitrary here, and the code is careful about which transforms can change a greedy result. Masks, penalties, and logit bias come first because they can shift the argmax. Temperature, min-p, and top-k/top-p come inside sample() because for a greedy request they are irrelevant, so the sampler can short-circuit:

# vllm/v1/sample/sampler.py
if sampling_metadata.all_random:
    greedy_sampled = None
else:
    greedy_sampled = self.greedy_sample(logits)
    if sampling_metadata.all_greedy:
        ...
        return greedy_sampled, processed_logprobs

Source: vllm/v1/sample/sampler.py

A pure-greedy batch never sorts a single logit; it returns from greedy_sample before sample() is even called. The interesting case is a mixed batch, where some requests in the same step want greedy decoding and others want random sampling. The GPU does not branch per row cheaply, so vLLM computes both answers for the whole batch (the greedy argmax for every row and the random draw for every row) and then selects per row with a torch.where keyed on whether each request’s temperature is below epsilon. The greedy rows pay for a random draw they throw away, and the random rows pay for an argmax they throw away, but both run as dense vectorized kernels with no data-dependent control flow, which on a GPU is far cheaper than trying to handle each row separately. Note the small honesty in apply_temperature: it rewrites a zero temperature to 1.0 before dividing, purely to avoid a divide-by-zero for the greedy rows whose result will be discarded anyway.

# vllm/v1/sample/sampler.py
@staticmethod
def apply_temperature(logits, temp, all_random):
    if not all_random:
        temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
    return logits.div_(temp.unsqueeze(dim=1))

Source: vllm/v1/sample/sampler.py
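The compute-both-then-select pattern can be rendered per row in plain Python. SAMPLING_EPS’s value and the helper names are assumptions, and the random draws are fixed stand-ins so the selection logic can be checked. On the GPU both sides are whole-batch tensor ops and the final choice is a single torch.where.

```python
SAMPLING_EPS = 1e-5  # hypothetical threshold for "temperature is zero"


def argmax(row):
    return max(range(len(row)), key=row.__getitem__)


def apply_temperature(rows, temps):
    # Greedy rows are divided by 1.0 only to avoid a divide-by-zero;
    # their scaled logits never decide the final token.
    return [
        [x / (1.0 if t < SAMPLING_EPS else t) for x in row]
        for row, t in zip(rows, temps)
    ]


def select_per_row(greedy_ids, random_ids, temps):
    # Both answers exist for every row; keep one per row by temperature.
    return [g if t < SAMPLING_EPS else r for g, r, t in zip(greedy_ids, random_ids, temps)]


rows = [[1.0, 3.0, 2.0], [0.5, 0.2, 0.9]]
temps = [0.0, 0.7]                        # row 0 greedy, row 1 random
greedy = [argmax(r) for r in rows]        # [1, 2]
scaled = apply_temperature(rows, temps)   # row 0 unchanged (divided by 1.0)
random_draws = [2, 0]                     # pretend results of a random draw on `scaled`
chosen = select_per_row(greedy, random_draws, temps)  # [1, 0]
```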

The whole thing runs in float32 (logits = logits.to(torch.float32)) even when the model ran in bf16, because the softmax and the cumulative sums that top-p needs are numerically nasty in low precision and this tensor is small enough that the cast is cheap.

Temperature is the one knob that literally reshapes the softmax curve, and the effect is hard to picture from the formula alone: dividing every logit by $T$ before the softmax sharpens the distribution toward the argmax when $T<1$ and flattens it toward uniform when $T>1$. The curves below take a fixed set of eight logits and apply three temperatures. At $T=0.5$ the top token carries over 80% of the mass (a near-greedy draw); at $T=1$ the model’s native distribution shows through; at $T=2$ the mass spreads out and the long tail becomes far more likely to be sampled.

Illustrative: probabilities computed from a fixed eight-entry logit vector [4, 3, 2.5, 2, 1.5, 1, 0.5, 0]; the shape is exact for these logits but the logits themselves are chosen for clarity, not measured.
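The headline numbers behind those curves can be reproduced directly. This sketch applies the caption’s logit vector at the three temperatures in plain Python; the max-shift before exponentiating is a standard numerical-stability step and does not change the result.

```python
import math


def softmax_with_temperature(logits, t):
    """Softmax of logits / t, shifted by the max for numerical stability."""
    scaled = [x / t for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


LOGITS = [4, 3, 2.5, 2, 1.5, 1, 0.5, 0]  # the illustrative vector from the caption
top_mass = {t: softmax_with_temperature(LOGITS, t)[0] for t in (0.5, 1.0, 2.0)}
# top_mass is about {0.5: 0.824, 1.0: 0.524, 2.0: 0.306}
```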

Where top-p came from

Top-p, or nucleus sampling, has a clean provenance. The Curious Case of Neural Text Degeneration (arXiv:1904.09751) observed that greedy and beam search on large LMs produce degenerate, repetitive text, and that naive top-k truncation either leaves in garbage or cuts off the natural variety of the distribution. Their fix was to sample from the smallest set of tokens whose cumulative probability exceeds a threshold $p$, a set whose size adapts to how peaked the distribution is. That adaptive nucleus is exactly what the cumulative-sum mask in vLLM implements; the paper is the fastest way to understand why top-p is the production default rather than top-k.

Sorting is the enemy

The textbook way to do top-p is to sort the vocabulary, take a cumulative sum, and mask the tail. vLLM’s native path does precisely that:

# vllm/v1/sample/ops/topk_topp_sampler.py
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
...
probs_sort = logits_sort.softmax(dim=-1)
probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)
top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)

Source: vllm/v1/sample/ops/topk_topp_sampler.py
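A pure-Python rendering of that ascending-sort mask, taking from the excerpt the rule that tail tokens are those whose ascending cumulative probability is at most 1 − p, with the single most likely token always kept. The function name is hypothetical. It shows the nucleus adapting its size to how peaked the distribution is.

```python
def top_p_keep(probs, p):
    """Indices kept by a top-p (nucleus) filter, mirroring the ascending-sort mask."""
    # Ascending order, like logits.sort(descending=False) in the native path.
    order = sorted(range(len(probs)), key=probs.__getitem__)
    kept, cum = set(), 0.0
    for rank, i in enumerate(order):
        cum += probs[i]                  # running cumsum over the ascending order
        in_tail = cum <= 1 - p           # the masked-out low-probability tail
        is_top = rank == len(order) - 1  # never mask the most likely token
        if not in_tail or is_top:
            kept.add(i)
    return kept


peaked = top_p_keep([0.9, 0.05, 0.03, 0.02], p=0.8)   # {0}
moderate = top_p_keep([0.6, 0.25, 0.1, 0.05], p=0.8)  # {0, 1}
flat = top_p_keep([0.25, 0.25, 0.25, 0.25], p=0.8)    # {0, 1, 2, 3}
```

A fixed top-k could not do this: the same k would be too many tokens for the peaked case and too few for the flat one.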

A full sort over a 128k-or-larger vocabulary, for every request, every step, is not free. So on CUDA, vLLM prefers FlashInfer’s sorting-free sampler, which uses a rejection-sampling scheme to draw from the truncated distribution without ever materializing a sorted order:

# vllm/v1/sample/ops/topk_topp_sampler.py
next_token_ids = flashinfer.sampling.top_k_top_p_sampling_from_logits(
    logits, k, p, deterministic=True
)

Source: vllm/v1/sample/ops/topk_topp_sampler.py

The honesty here is in the docstring: this is statistically equivalent to the sorting path, not bit-identical to it. The dispatch is decided once at construction in TopKTopPSampler.__init__, which binds self.forward to forward_cuda only when flashinfer_sampler_supported() and the logprobs mode doesn’t need post-filter logits (FlashInfer doesn’t expose them). Even forward_cuda falls back to the native path when there’s nothing to filter or when per-request RNG generators are present, which FlashInfer 0.2.3+ can’t honor. This is a recurring shape in vLLM: a fast kernel for the common case, guarded by a wall of correctness conditions that quietly route the awkward cases to the slow, simple path.

One more detail that connects to randomness done right. The obvious way to draw a token from a probability distribution is torch.multinomial, but it causes a CPU-GPU sync, which would defeat everything Chapter 10 set up. So random_sample uses the Gumbel-max trick instead. The trick is a small identity: for a categorical distribution with probabilities $p_i$, if you perturb each log-probability with independent Gumbel noise $g_i$ and then take the argmax, the token you land on is distributed exactly as if you had sampled from the original distribution:

$$\arg\max_i \left( \log p_i + g_i \right) \sim p$$

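Equivalently, dividing each probability by an independent unit-exponential sample $q_i$ and taking $\arg\max_i (p_i / q_i)$ gives the same draw. The reason is that if $q_i \sim \mathrm{Exp}(1)$ then $g_i = -\log q_i$ is standard Gumbel, so $\log(p_i / q_i) = \log p_i + g_i$, and because $\log$ is monotone the argmax is unchanged. The payoff is that an argmax is a plain reduction with no host round-trip, so the entire draw stays on the GPU. vLLM draws the exponential noise and takes that argmax: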

# vllm/v1/sample/ops/topk_topp_sampler.py
def random_sample(probs, generators, use_fp64_gumbel=False):
    """We use this function instead of torch.multinomial because
    torch.multinomial causes CPU-GPU synchronization."""
    q = empty_exponential_noise_like(probs, use_fp64_gumbel)
    if len(generators) != probs.shape[0]:
        q.exponential_()
    ...
    return sample_with_exponential_noise(probs, q)

Source: vllm/v1/sample/ops/topk_topp_sampler.py
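A CPU-side sanity check of the exponential form of the trick, in plain Python. The function name is hypothetical and the per-element loop stands in for vLLM’s vectorized divide-and-argmax; the point is only that the empirical frequencies match the target probabilities.

```python
import random


def sample_exponential_race(probs, rng):
    """One draw via argmax(p_i / q_i) with q_i ~ Exp(1): no sort, no cumulative sum."""
    best, best_score = 0, -1.0
    for i, p in enumerate(probs):
        q = max(rng.expovariate(1.0), 1e-300)  # guard against a zero draw
        score = p / q
        if score > best_score:
            best, best_score = i, score
    return best


PROBS = [0.5, 0.3, 0.2]
rng = random.Random(0)  # fixed seed so the check is reproducible
N = 20_000
counts = [0] * len(PROBS)
for _ in range(N):
    counts[sample_exponential_race(PROBS, rng)] += 1
freqs = [c / N for c in counts]  # close to PROBS, up to sampling noise
```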

The one mandatory sync

After all of that, the sampler returns GPU tensors. The comment in sampler.py is blunt: “These are GPU tensors.” The sampled token ids live in device memory. But the scheduler, the detokenizer, and the API server all run on the CPU, in (mostly) a different process. At some point the ids have to come down. That copy, a device-to-host transfer (D2H), is the single mandatory GPU-to-CPU synchronization of the entire step. The reason it is a synchronization and not just a copy is worth pinning down: GPU work is normally launched asynchronously, with the CPU queueing kernels and racing ahead without waiting for results. But to read an actual value the GPU computed, such as which token it sampled, the CPU has no choice but to stop and wait until that value is finished and copied down. Everything else, the whole forward pass, can stay asynchronous; this is the one place the CPU must learn a value the GPU computed.

The runner does this copy through a pinned host buffer for speed, and it is deliberate about how it waits:

# vllm/v1/worker/gpu_model_runner.py
def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
    # `tolist` would trigger a cuda wise stream sync, which
    # would block other copy ops from other cuda streams.
    # A cuda event sync would avoid such a situation.
    pinned = self.sampled_token_ids_pinned_cpu[: sampled_token_ids.shape[0]]
    pinned.copy_(sampled_token_ids, non_blocking=True)
    self.transfer_event.record()
    self.transfer_event.synchronize()
    return pinned.tolist()

Source: vllm/v1/worker/gpu_model_runner.py

The subtlety is the difference between two ways of waiting. A naive tensor.tolist() triggers a device-wide stream synchronize: it blocks until every outstanding operation on the device is done, which also stalls unrelated copy streams that happen to be running other work (the comment links a real regression that hurt a disaggregated setup, foreshadowing Chapter 17). A CUDA event is narrower: transfer_event.record() drops a marker into the stream right after this one copy, and transfer_event.synchronize() waits only until execution reaches that marker. So the CPU blocks for exactly this copy and nothing else. This is exactly the kind of micro-cost that Chapter 10 cared about: when the GPU finishes a decode step in under a millisecond, a stray device-wide sync is a measurable tax.
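The difference between the two ways of waiting can be captured in a deliberately toy model. ToyStream, ToyEvent and device_synchronize below are hypothetical classes, not CUDA or PyTorch APIs: each stream is an ordered queue of pending operations, and a host-side wait is modeled as draining that queue. What the model preserves is how much unrelated work the host is forced to wait for.

```python
from dataclasses import dataclass, field


@dataclass
class ToyStream:
    """Toy stand-in for a CUDA stream: an ordered queue of pending ops."""
    name: str
    pending: list = field(default_factory=list)
    done: list = field(default_factory=list)

    def launch(self, op: str) -> None:
        # Asynchronous launch: the host only enqueues and moves on.
        self.pending.append(op)

    def wait_until(self, count: int) -> None:
        # Host waits while ops complete in FIFO order until `count` are done.
        while len(self.done) < count and self.pending:
            self.done.append(self.pending.pop(0))


class ToyEvent:
    """Toy CUDA-style event: a marker at one position in one stream."""

    def __init__(self) -> None:
        self.stream = None
        self.position = 0

    def record(self, stream: ToyStream) -> None:
        self.stream = stream
        self.position = len(stream.done) + len(stream.pending)

    def synchronize(self) -> None:
        # Wait only until this stream reaches the marker; other streams are untouched.
        self.stream.wait_until(self.position)


def device_synchronize(streams) -> None:
    # Wait for everything outstanding on every stream.
    for s in streams:
        s.wait_until(len(s.done) + len(s.pending))
```

Recording the event right after the sampled-ids copy and synchronizing on it leaves an unrelated KV-transfer stream’s work pending, whereas device_synchronize makes the host wait for that stream too.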

Skipping even the sync: the async-scheduling shortcut

Chapter 10 introduced async scheduling, where the CPU plans step N+1 while the GPU runs step N. There is a chicken-and-egg problem buried in it: step N+1 needs the token sampled in step N as its input. If the CPU has to wait for that token to come down before it can build the next batch, the overlap collapses. vLLM’s answer is to not bring the token down at all on the hot path. In _bookkeeping_sync, the async branch keeps the sampled ids on the GPU and remembers their layout:

# vllm/v1/worker/gpu_model_runner.py
else:
    valid_sampled_token_ids = []
    invalid_req_indices = discard_sampled_tokens_req_indices.tolist()
    ...
    # Cache the sampled tokens on the GPU and avoid CPU sync.
    if self.input_batch.prev_sampled_token_ids is None:
        assert sampled_token_ids.shape[-1] == 1
        self.input_batch.prev_sampled_token_ids = sampled_token_ids

Source: vllm/v1/worker/gpu_model_runner.py

The next step’s input preparation copies prev_sampled_token_ids directly into the input buffer on-device, so the token round-trips GPU-to-GPU and never touches the CPU until it’s needed for output. Notice it builds an invalid_req_indices list instead of clearing tokens immediately: for partial (chunked-prefill) requests we sampled a junk token “for simplicity,” and those indices get blanked later. The host copy itself is deferred into AsyncGPUModelRunnerOutput, which fires the D2H on a separate copy stream and only blocks when someone calls get_output():

# vllm/v1/worker/gpu_model_runner.py
with torch.cuda.stream(async_output_copy_stream):
    async_output_copy_stream.wait_stream(default_stream)
    self.sampled_token_ids_cpu = self._sampled_token_ids.to(
        "cpu", non_blocking=True)
    ...
    self.async_copy_ready_event.record()

Source: vllm/v1/worker/gpu_model_runner.py

So the one mandatory sync is real, but it’s been pushed as far down the pipeline as it can go, off the critical path that gates the next forward.
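The shape of that deferral resembles a future. The sketch below is a hypothetical, CPU-only analogy built on concurrent.futures: DeferredSampledIds plays the role of AsyncGPUModelRunnerOutput, a worker thread stands in for the side copy stream, a plain list stands in for the GPU-resident prev_sampled_token_ids, and the "sample" step simply adds 1 to each previous token.

```python
from concurrent.futures import Future, ThreadPoolExecutor


class DeferredSampledIds:
    """Hypothetical analogue of a deferred D2H copy: started now, awaited later."""

    def __init__(self, pool: ThreadPoolExecutor, device_ids):
        # Kick off the "copy" on another worker, like a side copy stream.
        self._future: Future = pool.submit(self._copy_to_host, device_ids)

    @staticmethod
    def _copy_to_host(device_ids):
        return [list(row) for row in device_ids]

    def get_output(self):
        # The single blocking point, pushed as late as possible.
        return self._future.result()


def run_steps(first_ids, num_steps):
    device_prev = first_ids
    pending = []
    with ThreadPoolExecutor(max_workers=1) as pool:
        for _ in range(num_steps):
            # Next step's input is built from the previous ids directly ("GPU-to-GPU").
            next_input = [row[:] for row in device_prev]
            sampled = [[tok + 1] for (tok,) in next_input]  # stand-in forward + sample
            pending.append(DeferredSampledIds(pool, sampled))
            device_prev = sampled  # never round-trips through the host
        # The host learns the values only here, after every step was launched.
        return [p.get_output() for p in pending]
```

The loop never waits on a previous step’s output to build the next one; only the final get_output calls block, which mirrors how the real runner keeps the sync off the path that gates the next forward.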

Across the process boundary

Here the architecture from Chapter 1 reasserts itself. vLLM splits the frontend (HTTP, tokenization, detokenization, output processing) from the EngineCore (scheduler plus model runner) into separate processes, wired together over ZMQ. The core_client.py header states the roster plainly:

# vllm/v1/engine/core_client.py
* InprocClient: In process EngineCore (for V0-style LLMEngine use)
* SyncMPClient: ZMQ + background proc EngineCore (for LLM)
* AsyncMPClient: ZMQ + background proc EngineCore w/ asyncio (for AsyncLLM)

This split exists so that Python’s GIL on the busy HTTP server doesn’t steal cycles from the engine loop, and vice versa. The engine emits batches of EngineCoreOutput (token ids, finish reasons, logprob tensors already pulled to CPU) over the socket. On the frontend side, a single background task drains them. In AsyncLLM._run_output_handler, that loop pulls, chunks (so it never hogs the event loop), processes, and crucially handles aborts:

# vllm/v1/engine/async_llm.py
processed_outputs = output_processor.process_outputs(
    outputs_slice, outputs.timestamp, iteration_stats)
...
# 3) Abort any reqs that finished due to stop strings.
if processed_outputs.reqs_to_abort:
    await engine_core.abort_requests_async(
        processed_outputs.reqs_to_abort)

Source: vllm/v1/engine/async_llm.py

That last line is the seam where the cancellation race lives. Hold onto it.
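A stripped-down version of that loop shows the two behaviors that matter. In this sketch the engine is replaced by a list of output batches, and the processor and abort call are injected as plain callables; output_handler, CHUNK_SIZE and the dict-shaped outputs are hypothetical. Slices are separated by await asyncio.sleep(0), so other coroutines on the same event loop (the HTTP handlers) get a turn, and any abort list produced by processing is sent back to the engine.

```python
import asyncio

CHUNK_SIZE = 2  # hypothetical slice size, kept tiny so the example actually chunks


async def output_handler(batches, process_outputs, abort_requests):
    """Drain engine output batches in slices, yielding to the event loop between
    slices and forwarding any requests the frontend decided to abort."""
    for outputs in batches:
        slices = [outputs[i:i + CHUNK_SIZE] for i in range(0, len(outputs), CHUNK_SIZE)]
        for i, outputs_slice in enumerate(slices):
            reqs_to_abort = process_outputs(outputs_slice)
            if i + 1 < len(slices):
                await asyncio.sleep(0)  # let other coroutines run between slices
            if reqs_to_abort:
                await abort_requests(reqs_to_abort)
```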

Turning ids back into text, incrementally

The detokenizer’s job is to turn token ids back into the characters a user will read. The naive approach would be to call tokenizer.decode(all_ids) from scratch on every step, but that is both wasteful (re-decoding the entire prefix each step) and, more importantly, wrong at the boundaries. Byte-level BPE and SentencePiece (with byte fallback) tokenizers map tokens to bytes, not to characters: a single emoji or CJK character can span multiple tokens, so the bytes of one character can arrive split across two decode steps, and a half-character is not valid text. Decoding a prefix can also produce a different string than decoding the whole thing (the classic leading-space problem, where a token’s rendering depends on whether it follows whitespace, and the partial-UTF-8 problem just described). So vLLM detokenizes incrementally, holding a streaming decoder per request that remembers the in-progress byte state from one token to the next and only emits characters once they are complete. The fast path, for tokenizers >= 0.22.0, primes a native DecodeStream with the prompt and steps it one token at a time:

# vllm/v1/engine/detokenizer.py
self.stream = tokenizers.decoders.DecodeStream(
    ids=request.prompt_token_ids,
    skip_special_tokens=self.skip_special_tokens,
)

Source: vllm/v1/engine/detokenizer.py

Each new token is fed through _protected_step, which exists entirely to survive the rough edges of real tokenizers: it catches overflow on bad ids and, on an “Invalid prefix encountered” error from a non-monotonic UTF-8 sequence, resets the stream rather than crashing the request. Both branches cite real issues. Incremental detokenization looks trivial and is full of one-off bugs, which is why the code reads defensively.
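The partial-character problem, and the reset-instead-of-crash reflex, can be reproduced with Python’s built-in incremental UTF-8 decoder. This is an analogy, not vLLM’s fast path (which uses the tokenizers library’s DecodeStream): the vocabulary here is a hypothetical id-to-bytes table, and the sketch ignores leading-space and metaspace conventions entirely.

```python
import codecs


class ByteLevelDetokenizer:
    """Incremental detokenizer sketch for a byte-level vocabulary."""

    def __init__(self, vocab: dict):
        self.vocab = vocab  # hypothetical token id -> raw bytes table
        self.decoder = codecs.getincrementaldecoder("utf-8")()
        self.text = ""

    def step(self, token_id: int) -> str:
        """Feed one token; return only the characters it completes (possibly "")."""
        try:
            # Incomplete multi-byte sequences stay buffered inside the decoder.
            delta = self.decoder.decode(self.vocab[token_id])
        except UnicodeDecodeError:
            # In the spirit of _protected_step: reset rather than fail the request.
            self.decoder.reset()
            delta = ""
        self.text += delta
        return delta
```

Feeding the emoji U+1F600 as two tokens (bytes F0 9F, then 98 80) yields an empty delta on the first step and the whole character on the second, whereas decoding the prefix on its own with errors="replace" produces a U+FFFD replacement character that would have been streamed to the user.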

The stop-string buffer

A stop string is a piece of text that, when generated, should end the request. Stop strings are where detokenization stops being a pure decode and starts being a control-flow decision. Two facts make them tricky. First, a user-supplied stop string like "\n\n" might straddle a token boundary, so it can only be detected after decoding, in the character stream, not in the token-id stream. Second, and worse, if the stop string is not to be included in the output, you cannot stream the last few characters until you are sure they are not the beginning of a stop match. Imagine the stop string is "END" and you have just decoded "...the EN": you must not send those last two characters yet, because the next token might complete "END", in which case they need to vanish from the output. So the detokenizer holds back a buffer of $\max_s |s| - 1$ characters, where $s$ ranges over the configured stop strings and $|s|$ is the length of stop string $s$:

# vllm/v1/engine/detokenizer.py
if self.stop and not self.include_stop_str_in_output:
    self.stop_buffer_length = max(len(s) for s in self.stop) - 1
else:
    self.stop_buffer_length = 0

Source: vllm/v1/engine/detokenizer.py

get_next_output_text then trims that many trailing characters from each streamed delta (revealing the held-back tail only once the request is finished). And check_stop_strings searches only the newly-added characters plus enough lookback to catch a match that spans the previous chunk, then returns where to truncate. The point to internalize: a stop string is detected on the frontend, by inspecting decoded text, not by the engine inspecting token ids. That asymmetry is the whole reason the next section’s race exists.
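The hold-back rule and the lookback search fit in a small class. StopStringStreamer is a hypothetical sketch of the behavior described above for the case where the stop string is excluded from the output; it is not vLLM’s implementation, and it works on already-decoded text deltas.

```python
class StopStringStreamer:
    """Hold back max(len(s)) - 1 characters so a partial stop string is never streamed."""

    def __init__(self, stop):
        self.stop = stop
        self.hold = max((len(s) for s in stop), default=1) - 1
        self.text = ""         # decoded text kept so far
        self.sent = 0          # characters already streamed to the client
        self.finished = False

    def feed(self, delta: str) -> str:
        """Add newly decoded text; return the part that is safe to stream now."""
        if self.finished:
            return ""
        # A new match can only start in the last `hold` old characters or later.
        search_from = max(0, len(self.text) - self.hold)
        self.text += delta
        hits = [i for i in (self.text.find(s, search_from) for s in self.stop) if i != -1]
        if hits:
            self.text = self.text[: min(hits)]  # earliest match wins; drop the stop string
            self.finished = True
        limit = len(self.text) if self.finished else len(self.text) - self.hold
        out = self.text[self.sent : max(limit, self.sent)]
        self.sent += len(out)
        return out

    def flush(self) -> str:
        """Request ended for another reason (e.g. length): release the held-back tail."""
        self.finished = True
        out = self.text[self.sent :]
        self.sent = len(self.text)
        return out
```

With stop=["END"], feeding "...the EN" streams only "...the ", and feeding "D of it" next streams nothing and finishes the request: the held-back "EN" never reaches the client. Because the held-back window is exactly as long as the lookback, a truncation can never cut into text that was already sent.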

The abort-on-stop race

Put the pieces together. The EngineCore schedules and samples; it does not detokenize, so it cannot see a stop string (it only knows about stop token ids and length limits, which it can enforce itself). The frontend detokenizes and therefore is the only party that can detect a stop string. When it does, the engine is still happily generating that request, one step ahead, wasting GPU on tokens nobody will read. The sequence diagram below traces exactly this: the engine keeps stepping while the abort message is still in flight, and the frontend has to discard the outputs that arrive in the gap.

sequenceDiagram
    participant E as "EngineCore (engine process)"
    participant F as "Output processor (frontend)"
    E->>F: "step N output: token ids"
    F->>F: "detokenize, no stop string yet"
    E->>F: "step N+1 output: completes the stop string"
    F->>F: "detect stop string, finish request locally"
    F->>E: "abort_requests_async(req_id)"
    Note over E: "engine already sampled step N+2 before abort arrived"
    E->>F: "step N+2 output: stale token ids"
    F->>F: "request_states.get(req_id) is None, drop it"
    E->>E: "free request, stop scheduling it"

So the output processor, on detecting a stop string, both finishes the request locally and signals that the engine must be told to abort it:

# vllm/v1/engine/output_processor.py
stop_string = req_state.detokenizer.update(
    new_token_ids, finish_reason == FinishReason.STOP)
if stop_string:
    finish_reason = FinishReason.STOP
    stop_reason = stop_string
...
if not engine_core_output.finished:
    # If req not finished in EngineCore, but Detokenizer
    # detected stop string, abort needed in EngineCore.
    reqs_to_abort.append(req_id)

Source: vllm/v1/engine/output_processor.py

That reqs_to_abort list is what flowed back to the abort_requests_async call in the output handler. The race is the window between “frontend decides to stop” and “engine actually frees the request.” In that window the engine may already have sampled the next token (or several, under speculation) for a request the frontend considers done. The frontend must therefore ignore late-arriving outputs for requests it has already finished, which is exactly what the top of process_outputs does:

# vllm/v1/engine/output_processor.py
req_state = self.request_states.get(req_id)
if req_state is None:
    # Ignore output for already-aborted request.
    continue

Source: vllm/v1/engine/output_processor.py
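The frontend’s half of the handshake reduces to three rules: finish locally on a stop string, ask for an engine abort only if the engine has not already finished the request, and drop any output whose request state is gone. ToyOutputProcessor and EngineOutput below are hypothetical stand-ins that carry decoded text rather than token ids and check a single stop string.

```python
from dataclasses import dataclass


@dataclass
class EngineOutput:
    req_id: str
    text: str                     # stand-in for new token ids, already decoded
    engine_finished: bool = False


class ToyOutputProcessor:
    def __init__(self, stop: str):
        self.stop = stop
        self.request_states = {}  # req_id -> accumulated text

    def add_request(self, req_id: str) -> None:
        self.request_states[req_id] = ""

    def process_outputs(self, outputs):
        finished, reqs_to_abort = {}, []
        for out in outputs:
            text = self.request_states.get(out.req_id)
            if text is None:
                continue  # ignore output for an already-finished request
            text += out.text
            if self.stop in text:
                finished[out.req_id] = text[: text.index(self.stop)]
                del self.request_states[out.req_id]
                if not out.engine_finished:
                    reqs_to_abort.append(out.req_id)  # engine is still generating it
            elif out.engine_finished:
                finished[out.req_id] = text
                del self.request_states[out.req_id]
            else:
                self.request_states[out.req_id] = text
        return finished, reqs_to_abort
```

The late-arriving output in the final step of the sequence diagram corresponds to a third call after the abort was requested: it is silently discarded because the request state is already gone.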

The other direction of the same race is client disconnect. When an HTTP client hangs up, the generate() async generator is cancelled or garbage-collected, and vLLM must abort the in-flight request; otherwise the engine keeps spending GPU time on tokens nobody will read until the request reaches its own stop condition or length limit:

# vllm/v1/engine/async_llm.py
except (asyncio.CancelledError, GeneratorExit):
    if q is not None:
        await self.abort(q.request_id, internal=True)

Source: vllm/v1/engine/async_llm.py
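The cancellation path can be exercised with plain asyncio. In this sketch, ToyEngineClient and generate are hypothetical stand-ins: an asyncio.Queue plays the role of the per-request output queue, and cancelling the consuming task plays the role of the HTTP connection dropping. The except clause mirrors the pattern quoted above: abort, then re-raise.

```python
import asyncio


class ToyEngineClient:
    """Hypothetical stand-in for the engine client; it just records aborts."""

    def __init__(self):
        self.aborted = []

    async def abort(self, request_id: str) -> None:
        self.aborted.append(request_id)


async def generate(engine: ToyEngineClient, request_id: str, queue: asyncio.Queue):
    """Stream outputs for one request; abort it if the consumer goes away."""
    try:
        while True:
            out = await queue.get()
            if out is None:  # normal completion: no abort needed
                return
            yield out
    except (asyncio.CancelledError, GeneratorExit):
        # Client hung up (task cancelled) or the generator was closed early.
        await engine.abort(request_id)
        raise
```

Re-raising matters: swallowing CancelledError would leave the cancelled task looking as if it completed normally.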

Both directions resolve to the same primitive, an abort message to the engine, and the same defensive rule, drop outputs whose request state is already gone. The reason most people get this wrong is that a single-process mental model hides it: if detokenization and generation shared a loop, you’d stop the moment you saw the stop string. Across a process boundary, with one side a step ahead, stopping is a two-phase handshake with a guaranteed window of wasted work and a guaranteed stream of stale outputs to discard.

Out the door as SSE

The final hop is almost anticlimactic. The OpenAI-compatible server iterates the RequestOutputs yielded by generate() and serializes each into a Server-Sent Events frame. The streaming generator’s shape is exactly what you’d expect from a FastAPI endpoint:

# vllm/entrypoints/openai/chat_completion/serving.py
async for res in result_generator:
    ...
    yield f"data: {data}\n\n"
...
yield "data: [DONE]\n\n"

Source: vllm/entrypoints/openai/chat_completion/serving.py

The data: ...\n\n framing and the terminal [DONE] sentinel are the SSE wire format the OpenAI client libraries expect. The stream_interval and DELTA-mode logic in RequestState.make_request_output decide how often to emit and whether to send full text or just the new delta, trading client chattiness against per-token latency visibility. But by the time text reaches this loop, every hard problem is already solved upstream: the sampling distribution, the one D2H sync, the incremental decode, the stop handshake.
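The framing itself is small enough to write out, along with the client-side parse. This is a hedged sketch: sse_frames and parse_sse are hypothetical helpers, and the JSON payload is trimmed to the choices[0].delta.content path (real OpenAI-style chunks also carry fields such as id, object, created and model).

```python
import json


def sse_frames(deltas):
    """Frame text deltas as Server-Sent Events: one 'data: <json>' frame per
    chunk, each terminated by a blank line, then the [DONE] sentinel."""
    for delta in deltas:
        payload = {"choices": [{"index": 0, "delta": {"content": delta}}]}
        yield f"data: {json.dumps(payload)}\n\n"
    yield "data: [DONE]\n\n"


def parse_sse(stream_text):
    """Client side: split on blank lines, strip 'data: ', stop at [DONE]."""
    pieces = []
    for frame in stream_text.split("\n\n"):
        if not frame.startswith("data: "):
            continue
        data = frame[len("data: "):]
        if data == "[DONE]":
            break
        pieces.append(json.loads(data)["choices"][0]["delta"]["content"])
    return "".join(pieces)
```

One detail makes the blank-line framing safe: json.dumps escapes newlines inside the generated text, so a delta containing a blank line cannot be mistaken for a frame boundary.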

What’s unsolved, and what’s next

The egress path is mature but not finished. Detokenization is single-threaded per request and runs in the frontend process; under very high concurrency it can become a CPU bottleneck that the chunked output-handler loop only partially mitigates. The abort handshake is correct but not instantaneous, so a deployment with long stop-string-terminated generations always burns some tokens past the stop, and under speculative decoding (Chapter 13) it can be several. And the GPU-resident sampled-token shortcut interacts delicately with pipeline parallelism, where the first and last stages don’t share memory and the scheduler has to ferry tokens back the long way, a caveat the code calls out by name.

Two threads from this chapter feed directly forward. The rejection-sampler machinery glimpsed in parse_output is the verification step that Chapter 13 builds speculative decoding on: the sampler already knows how to accept-or-reject a batch of candidate tokens. And the allowed-token mask applied right at the top of the sampler is the GPU-side half of the CPU-computes-mask / GPU-applies-mask handshake that Chapter 14 uses for grammar-constrained generation, which must additionally roll that mask back over speculative tokens the verifier rejected. The sampler, in other words, is not the end of the pipeline. It’s the seam where the next two chapters’ techniques splice in.

Further reading

  • The Curious Case of Neural Text Degeneration — arXiv:1904.09751 — introduces top-p/nucleus sampling and explains why adaptive-nucleus truncation beats greedy/top-k for open-ended generation; the provenance of vLLM’s default sampler.

Quantization: weights, KV cache, and activations

In Chapter 3 we established the asymmetry that drives this whole book: a decode step is memory-bound. To emit one token for one sequence, the GPU streams the entire model’s weights from HBM through the compute units, does a trivially small amount of arithmetic per byte, and writes one token back. The arithmetic units sit mostly idle, waiting on the memory system. Chapter 4 then showed that the second thing the GPU must read on every step is the KV cache, which grows linearly with context and is what actually caps concurrency.

Put those two facts together and a blunt lever falls out. If decode time is dominated by bytes read, then reading fewer bytes buys you time almost one-for-one. Halve the bytes of the weights and a memory-bound decode step gets close to $2\times$ faster; shrink the KV cache and you fit more sequences in the same HBM, which is more concurrency at the same latency. Quantization is the family of techniques that does exactly this: store numbers in fewer bits than the BF16 the model was trained in.

A word on the units, since the whole chapter turns on them. BF16 (“brain floating point, 16-bit”) is the format models are typically trained in: two bytes per number, with enough exponent range to represent both tiny gradients and large activations. FP8 is a one-byte floating-point format; it comes in two flavors that trade range against precision, e4m3 (4 exponent bits, 3 mantissa bits: more precision, narrower range) and e5m2 (5 exponent bits, 2 mantissa bits: wider range, less precision). INT4 and INT8 are 4-bit and 8-bit integers, which carry no exponent at all and so need an explicit scale (a multiplier) to map the small integer back to the real value it stands for. The arithmetic of the lever is simple: BF16 to FP8 halves the bytes, BF16 to INT4 quarters them, and on a memory-bound step that fraction is roughly the speedup ceiling. The bars below put the two sides of that trade next to each other: bytes per weight element (what you store and stream) and the corresponding decode speedup ceiling (16 bits divided by the format’s bits).

Speedup ceiling is the upper bound for a perfectly memory-bound decode step (BF16 bits / format bits); real speedups fall short of it because of overheads like on-the-fly dequant and scale reads.
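The same arithmetic as a few lines of code, for a hypothetical 70-billion-parameter model. The helper names are illustrative, and the footprint deliberately ignores the scales and zero-points that real quantized checkpoints also store.

```python
FORMAT_BITS = {"bf16": 16, "fp8": 8, "int8": 8, "int4": 4}


def weight_footprint_gib(num_params: float, fmt: str) -> float:
    """GiB of weights stored (and streamed per decode step), ignoring scale overhead."""
    return num_params * FORMAT_BITS[fmt] / 8 / 2**30


def speedup_ceiling(fmt: str) -> float:
    """Upper bound for a perfectly memory-bound step: BF16 bits / format bits."""
    return FORMAT_BITS["bf16"] / FORMAT_BITS[fmt]
```

At 70e9 parameters this gives about 130.4 GiB in BF16, half that in FP8, and a quarter in INT4, with speedup ceilings of 1x, 2x and 4x respectively.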

That is the optimistic framing. The pessimistic one, which we will spend most of the chapter on, is that “fewer bits” is not a single knob. There are three distinct things you can quantize, each with its own accuracy cost, its own kernel requirements, and its own place in the forward pass: the weights, the KV cache, and the activations. They are usually discussed together and they are genuinely different problems. This chapter walks them in the order their payoff and their difficulty suggest: weights first (biggest, easiest win), KV cache second (the concurrency lever), activations last (the hardest, because they are dynamic).

It helps to see all three on the same picture before we separate them. The diagram below traces a single decode step through one transformer layer and marks the three quantizable tensors at the points where the GPU actually reads or writes them. The weights are read at every linear layer (the GEMMs in attention’s projections and in the MLP). The activations are the intermediate tensors that flow along the arrows from one operation to the next, recomputed fresh on every step. The KV cache is read and written at the attention block, where the new token’s K and V are appended and the whole history is read back. Three different tensors, three different lifetimes: weights live for the whole serving run, the KV cache lives for the duration of one request, and activations live for the span of one operation.

flowchart LR
    X["activation in (per token)"] --> QKV["QKV projection (GEMM)"]
    W1["weights (static)"] -. read .-> QKV
    QKV --> ATT["attention"]
    KV["KV cache (per request)"] -. "read and write" .-> ATT
    ATT --> O["output projection (GEMM)"]
    W2["weights (static)"] -. read .-> O
    O --> MLP["MLP (GEMMs)"]
    W3["weights (static)"] -. read .-> MLP
    MLP --> Y["activation out (per token)"]

Why weights are the natural first target

The weights are the largest single thing read per decode step and, crucially, they are static. They do not change between requests or between steps, so you can quantize them once, offline, and ship the quantized checkpoint. All the hard statistical work, finding scales that minimize error, happens before serving ever starts. This is the regime of post-training quantization, and it is where the canonical research lives.

Two papers define the practical landscape. GPTQ (arXiv:2210.17323) showed you can quantize an LLM’s weights to 3-4 bits in one pass by greedily rounding columns and using second-order (Hessian) information to compensate the not-yet-quantized weights for each rounding error, keeping accuracy where naive rounding falls apart. AWQ (arXiv:2306.00978) made the sharper observation that not all weights matter equally: a small fraction of “salient” weight channels, identified by the magnitude of the activations that flow through them, dominate the error, so you protect those by per-channel scaling and quantize the rest aggressively. Both are worth reading not for the algorithm details (vLLM consumes their output, it does not run them) but for the mental model: 4-bit weight-only quantization is accurate enough for production because the error is concentrated and can be steered, not because rounding to 4 bits is innocuous.

vLLM’s job is the serving half: take a checkpoint someone already quantized with GPTQ, AWQ, or llm-compressor, and run a forward pass that is actually faster. The configuration surface is deliberately small. vllm/config/quantization.py names the schemes a user can ask for:

# vllm/config/quantization.py
QUANT_KEY_NAMES: dict[str, QuantKey] = {
    "fp8_per_tensor_static": kFp8StaticTensorSym,
    "fp8_per_tensor_dynamic": kFp8DynamicTensorSym,
    "fp8_per_token": kFp8DynamicTokenSym,
    "fp8_per_channel_static": kFp8StaticChannelSym,
    "fp8_per_block_static": kFp8Static128BlockSym,
    "fp8_per_block_dynamic": kFp8Dynamic128Sym,
    "mxfp8": kMxfp8Dynamic,
    "mxfp4": kMxfp4Dynamic,
    "int8_per_channel_static": kInt8StaticChannelSym,
}

Read the suffixes carefully, because they encode the real design space along two axes. The first axis is granularity: how many real numbers share a single scale. Recall that a low-bit integer only means something once you multiply it by a scale; the question is how finely you vary that scale across the tensor. per_tensor uses one scale for the entire weight matrix, which is the cheapest to store and apply but coarse, because one multiplier has to serve numbers of wildly different magnitude. per_channel gives each row or column its own scale; per_token gives each token’s activation vector its own scale; per_block slices the tensor into fixed-size blocks (commonly 128 elements) and gives each block a scale. Finer granularity means the scale tracks local magnitude more closely, so the rounding error is smaller, at the cost of storing and reading more scales.

The second axis is timing: when the scale is computed. A static scale is fixed ahead of time, measured during an offline calibration run and baked into the checkpoint, so serving just reads it. A dynamic scale is computed at runtime from the actual tensor about to be quantized, which adapts perfectly to each input but adds a small reduction (find the max, derive the scale) on the critical path. Static is free at serving time but only as good as the calibration data; dynamic is always well-matched to the input but never free. Every quantization decision in this chapter is some point in this granularity-by-timing space, and the suffix on each scheme name tells you exactly where it sits.
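Both axes show up in a tiny symmetric INT8 experiment. The sketch below is illustrative only (plain Python lists, hypothetical helper names, symmetric quantization with a 127 clamp). On the granularity axis, a per-tensor scale sized for a row of values near 100 rounds a row of values near 0.01 to all zeros, while per-channel scales keep that row accurate. On the timing axis, a static scale calibrated on small values clips a larger value seen at serving time, while a dynamic scale adapts.

```python
def quantize_int8(values, scale):
    """Symmetric INT8 round trip: q = clamp(round(x / scale), -127, 127), return q * scale."""
    return [max(-127, min(127, round(x / scale))) * scale for x in values]


def dynamic_scale(values):
    """Dynamic timing: derive the scale from the tensor at hand (a max-reduction)."""
    amax = max(abs(x) for x in values)
    return amax / 127 if amax > 0 else 1.0


def per_tensor(rows):
    """Coarsest granularity: one scale shared by every element."""
    s = dynamic_scale([x for row in rows for x in row])
    return [quantize_int8(row, s) for row in rows]


def per_channel(rows):
    """Finer granularity: one scale per row (channel)."""
    return [quantize_int8(row, dynamic_scale(row)) for row in rows]


def max_abs_error(xs, ys):
    return max(abs(x - y) for x, y in zip(xs, ys))
```

For rows = [[100.0, -50.0, 25.0], [0.013, -0.007, 0.021]], per_tensor(rows)[1] collapses to [0.0, 0.0, 0.0], while per_channel keeps the second row within about 1e-4. A static scale calibrated on [1.0, -0.5, 0.25] turns a serving-time value of 3.0 into 1.0, whereas dynamic_scale([3.0]) reproduces it.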

The dispatch seam, again

Here is the load-bearing connection to Chapter 9. There we saw attention dispatched behind an opaque custom op so the engine could pick the best kernel for the hardware without the model code knowing. Quantized linear layers do exactly the same thing, for exactly the same reason: a 4-bit weight times a 16-bit activation is not a GEMM any stock library implements, so vLLM carries its own family of mixed-precision GEMM kernels and chooses among them at load time.

The chooser is choose_mp_linear_kernel, and it walks a platform-specific priority list trying each kernel until one says it can handle this layer:

# vllm/model_executor/kernels/linear/__init__.py
for kernel in platform_kernels:
    if kernel.__name__ in envs.VLLM_DISABLED_KERNELS:
        ...
        continue
    if (compute_capability is not None
            and kernel.get_min_capability() > compute_capability):
        ...
        continue
    can_implement, failure_reason = kernel.can_implement(config)
    if can_implement:
        return kernel

Source: vllm/model_executor/kernels/linear/__init__.py

The priority order is where the performance engineering lives:

# vllm/model_executor/kernels/linear/__init__.py
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
    PlatformEnum.CUDA: [
        CutlassW4A8LinearKernel,
        MacheteLinearKernel,
        AllSparkLinearKernel,
        MarlinLinearKernel,
        HummingLinearKernel,
        ConchLinearKernel,
        ExllamaLinearKernel,
        TritonW4A16LinearKernel,
    ],
    ...
}

Source: vllm/model_executor/kernels/linear/__init__.py

This is the same pattern as Chapter 9’s attention backend registry: an ordered list of candidates, a capability gate, and a can_implement predicate. The flowchart below traces the chooser’s loop for a single layer. It walks candidates in priority order; for each one it checks three things in turn (is the kernel disabled by an env var, does the GPU’s compute capability meet the kernel’s minimum, and does can_implement accept this layer’s exact dtype/granularity) and returns the first kernel that clears all three. “Compute capability” is NVIDIA’s version number for a GPU architecture (Ampere is 80, Hopper is 90, Blackwell higher still); a kernel that needs 90 simply will not be considered on an 80-class card. Because the list is ordered fastest-first, the chooser naturally lands on the best kernel the hardware can actually run.

flowchart TD
    START["choose_mp_linear_kernel for this layer"] --> NEXT{"more kernels in priority list?"}
    NEXT -->|no| FAIL["raise: no kernel can serve this config"]
    NEXT -->|yes| K["take next kernel"]
    K --> DIS{"disabled by VLLM_DISABLED_KERNELS?"}
    DIS -->|yes| NEXT
    DIS -->|no| CAP{"GPU compute capability meets kernel minimum?"}
    CAP -->|no| NEXT
    CAP -->|yes| IMPL{"can_implement accepts this layer?"}
    IMPL -->|no| NEXT
    IMPL -->|yes| WIN["return this kernel"]

Machete is the Hopper-and-newer path (get_min_capability returns 90); Marlin is the broadly-compatible workhorse that runs on older cards. The names you will hear most in 2026 production, Marlin and Machete, are simply the two entries in this list that win on the most common hardware.
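The chooser’s control flow is easy to replay with stand-in kernels. Everything below is hypothetical: ToyMachete, ToyMarlin and ToyTriton are illustrative classes, their capability floors and supported bit widths are made up for the example (only Machete’s 90 comes from the text), and disabled plays the role of VLLM_DISABLED_KERNELS.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LayerConfig:
    weight_bits: int


class ToyKernel:
    """Hypothetical kernel interface mirroring the three gates in the flowchart."""
    min_capability = 0
    supported_bits = ()

    @classmethod
    def get_min_capability(cls) -> int:
        return cls.min_capability

    @classmethod
    def can_implement(cls, config: LayerConfig) -> tuple:
        if config.weight_bits not in cls.supported_bits:
            return False, f"{cls.__name__}: {config.weight_bits}-bit weights unsupported"
        return True, None


class ToyMachete(ToyKernel):
    min_capability, supported_bits = 90, (4, 8)   # Hopper and newer


class ToyMarlin(ToyKernel):
    min_capability, supported_bits = 80, (4, 8)   # illustrative floor


class ToyTriton(ToyKernel):
    min_capability, supported_bits = 0, (4,)      # generic fallback


PRIORITY = [ToyMachete, ToyMarlin, ToyTriton]     # fastest first


def choose_kernel(config, compute_capability, disabled=frozenset()):
    failures = []
    for kernel in PRIORITY:
        if kernel.__name__ in disabled:
            failures.append(f"{kernel.__name__}: disabled")
            continue
        if compute_capability is not None and kernel.get_min_capability() > compute_capability:
            failures.append(f"{kernel.__name__}: needs capability {kernel.get_min_capability()}")
            continue
        ok, reason = kernel.can_implement(config)
        if ok:
            return kernel
        failures.append(reason)
    raise ValueError("no kernel can serve this config: " + "; ".join(failures))
```

A 4-bit layer lands on ToyMachete on a capability-90 card, falls through to ToyMarlin on an 80-class card or when ToyMachete is disabled, and reaches the generic fallback below that; a 3-bit layer exhausts the list and raises with every rejection reason attached.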

What does the kernel actually do? Machete’s apply is the clearest illustration that “mixed-precision GEMM” is a real, distinct operation:

# vllm/model_executor/kernels/linear/mixed_precision/machete.py
output = ops.machete_mm(
    a=x_2d,
    b_q=w_q,
    b_type=c.weight_type,
    b_group_zeros=w_zp,
    b_group_scales=w_s,
    b_group_size=c.group_size,
)

Source: vllm/model_executor/kernels/linear/mixed_precision/machete.py

The activations a arrive in BF16/FP16; the weight b_q is still packed 4-bit integers with its group scales (w_s) and zero-points (w_zp) alongside. (A “zero-point” is the integer value that represents real zero; together with the scale it defines the affine map from a 4-bit integer $q$ back to a real number, $\text{real} = s \cdot (q - z)$, where $s$ is the scale and $z$ the zero-point.) The kernel reads the weight as 4 bits per element, dequantizes on the fly inside the GEMM, and accumulates in high precision.

That on-the-fly dequant is the whole trick, and it is worth stating exactly why it works. The expensive thing on a decode step is moving the weight from HBM into the compute units; that move happens at 4 bits per element, so it is $4\times$ cheaper than BF16. Only once a small tile of weights has arrived in fast on-chip memory does the kernel expand it back to high precision and multiply. The memory traffic is 4-bit (the win), but the math is effectively 16-bit (the accuracy). The expansion costs some arithmetic, but on a memory-bound step the arithmetic units were idle anyway, so it is close to free. This is “weight-only” quantization, often written W4A16: weights at 4 bits, activations at 16. It is the safest, most accurate form, and it is exactly the right tool for a memory-bound decode step where bytes are the bottleneck and FLOPs are free.
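The storage-versus-math split can be seen in miniature. The sketch below quantizes a weight row to 4 bits with per-group scales and zero-points, packs two values per byte, and computes a dot product by dequantizing each value on the fly with real = s · (q - z). It is a hypothetical, pure-Python illustration of the W4A16 idea, not Machete’s or Marlin’s algorithm: the min/max choice of scale stands in for the far more careful GPTQ/AWQ calibration, and the low-nibble-first packing order is arbitrary (real checkpoint formats each define their own).

```python
def quantize_groups_uint4(weights, group_size):
    """Asymmetric 4-bit quantization per group: q in [0, 15], real ~= s * (q - z)."""
    q, scales, zeros = [], [], []
    for g in range(0, len(weights), group_size):
        group = weights[g:g + group_size]
        lo, hi = min(min(group), 0.0), max(max(group), 0.0)  # range includes 0
        s = (hi - lo) / 15 or 1.0                            # guard all-zero groups
        z = max(0, min(15, round(-lo / s)))                  # integer zero-point
        scales.append(s)
        zeros.append(z)
        q.extend(max(0, min(15, round(w / s) + z)) for w in group)
    return q, scales, zeros


def pack_nibbles(q):
    """Two 4-bit values per byte (low nibble first): 4x fewer bytes than BF16."""
    if len(q) % 2:
        q = q + [0]
    return bytes(q[i] | (q[i + 1] << 4) for i in range(0, len(q), 2))


def unpack_nibbles(packed, n):
    out = []
    for b in packed:
        out.extend((b & 0xF, b >> 4))
    return out[:n]


def w4a16_dot(x, packed, scales, zeros, group_size, n):
    """Read 4-bit weights, dequantize per group on the fly, accumulate in full precision."""
    q = unpack_nibbles(packed, n)
    acc = 0.0
    for i, xi in enumerate(x):
        g = i // group_size
        acc += xi * scales[g] * (q[i] - zeros[g])
    return acc
```

Eight weights pack into 4 bytes where BF16 would need 16, and each dequantized weight lands within half a group scale of the original, so the dot product’s error is bounded by the sum of |x_i| · s/2.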

This is the crucial difference from quantizing activations, which we reach at the end of the chapter. The diagram below contrasts the two. In weight-only (W4A16) the storage is low-precision but the multiply is high-precision, so the kernel pays a dequant before every multiply. In weight-and-activation (for example W8A8) both operands enter the multiply already low-precision, so the tensor cores do genuine low-precision math, which speeds up compute-bound prefill as well, but only on hardware that has the matching low-precision tensor-core path.

flowchart TD
    subgraph WO["W4A16: weight-only"]
        A1["activation: BF16"] --> M1["multiply (high precision)"]
        WQ1["weight: 4-bit in HBM"] --> D1["dequant to BF16 on-chip"]
        D1 --> M1
        M1 --> ACC1["accumulate (typically FP32), output BF16"]
    end
    subgraph WA["W8A8: weight and activation"]
        A2["activation: quantize to FP8"] --> M2["multiply (FP8 tensor core)"]
        WQ2["weight: FP8 in HBM"] --> M2
        M2 --> ACC2["accumulate (typically FP32), output BF16"]
    end
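The W8A8 side of that picture can be sketched with integers. Python has no FP8 type, so this hypothetical sketch uses the INT8 variant of W8A8 as a stand-in: activations get a dynamic per-token scale, weights get a per-channel scale (computed inline here for brevity; in practice it would be static, fixed offline), the inner product runs entirely on small integers, and the two scales are applied once per output element.

```python
def quantize_sym_int8(values):
    """Symmetric INT8 with one scale for the whole vector."""
    amax = max(abs(v) for v in values)
    scale = amax / 127 if amax else 1.0
    return [round(v / scale) for v in values], scale


def w8a8_matmul(x_rows, w_cols):
    """x_rows: tokens x K activations; w_cols: one length-K column per output channel."""
    wq = [quantize_sym_int8(col) for col in w_cols]    # per-channel weight scales
    out = []
    for row in x_rows:
        xq, sx = quantize_sym_int8(row)                 # per-token, at runtime
        out_row = []
        for qcol, sw in wq:
            acc = sum(a * b for a, b in zip(xq, qcol))  # integer multiply-accumulate
            out_row.append(acc * sx * sw)               # rescale once per output
        out.append(out_row)
    return out
```

Each quantized operand is within half a scale step of the original, so every output stays within K · max|x| · max|w| · (2/254 + 1/254²) of the exact float result; the design point is that both operands are already low-precision when they meet, which is what lets real hardware run the multiply on low-precision tensor cores.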

The plumbing that connects checkpoint to kernel is uniform. AWQ’s linear method, for instance, builds the kernel config and calls the chooser at weight-creation time:

# vllm/model_executor/layers/quantization/awq_marlin.py
kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
...
self.kernel = kernel_type(
    mp_linear_kernel_config,
    w_q_param_name="qweight",
    w_s_param_name="scales",
    w_zp_param_name="qzeros",
)

Source: vllm/model_executor/layers/quantization/awq_marlin.py

and apply is then a one-liner that delegates to whatever kernel won, return self.kernel.apply_weights(layer, x, bias). AWQ’s checkpoint quirks (a non-standard 4-bit packing order, _REVERSE_AWQ_PACK_ORDER, repacked in process_weights_after_loading) are absorbed before the kernel ever sees the weights, so the same Marlin/Machete kernels serve GPTQ and AWQ and compressed-tensors checkpoints alike. The checkpoint format is the front door; the kernel is the back end; the seam between them is choose_mp_linear_kernel.

By 2026 the dominant checkpoint format is compressed-tensors, the llm-compressor output format. Its config class is a dispatcher that reads the per-layer quant args and resolves them to a concrete scheme:

# vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
if self._is_nvfp4_format(weight_quant):
    if input_quant is None:
        return CompressedTensorsW4A4Fp4(use_a16=True)
    ...
    return CompressedTensorsW4A4Fp4()
...
if (self._is_wNa16_group_channel(weight_quant, input_quant)
        and (format == CompressionFormat.pack_quantized.value)
        and (weight_quant.num_bits in WNA16_SUPPORTED_BITS)):
    return CompressedTensorsWNA16(...)

Source: vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

The scheme names are the chapter’s three targets made literal: W4A16 (weight-only, the case above), W8A8 (weights and activations at 8 bits), W4A4 (both at 4). The split between W*A16 and W*A8/W*A4 is precisely the line between “I only quantized weights” and “I also quantized activations”, and the rest of the chapter is about why crossing that line is hard.

The KV cache: a separate budget, a separate dtype

Weight quantization shrinks the model. It does nothing for the KV cache, which Chapter 4 identified as the real concurrency limit. So the KV cache gets its own, completely independent knob, kv_cache_dtype, set in vllm/config/cache.py:

# vllm/config/cache.py
CacheDType = Literal[
    "auto", "float16", "bfloat16",
    "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc", "fp8_ds_mla",
    ...
    "int8_per_token_head", "fp8_per_token_head", "nvfp4",
]

Storing KV in fp8 instead of BF16 halves the bytes per cached token, which roughly doubles how much context (or how many concurrent sequences) fits in the same HBM. The curves below show KV-cache footprint growing linearly with context length for a single sequence on a representative 70B-class model (80 layers, 8 KV heads, head dim 128, so $2 \times 80 \times 8 \times 128 = 163{,}840$ cached elements per token, or $327{,}680$ bytes (320 KiB) per token at BF16’s two bytes each): the fp8 line sits at exactly half the slope, so any given HBM budget reaches roughly twice the context. This is orthogonal to weight quantization: you can run BF16 weights with an fp8 cache, or 4-bit weights with a BF16 cache, in any combination.

Illustrative: slopes are exact (fp8 is half of BF16), absolute GiB values assume the 70B-class config above; a different model’s per-token byte count shifts both lines but not their 2:1 ratio.
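The per-token arithmetic, as a pair of helpers. The 40 GiB KV budget used in the check is a hypothetical figure, and the helpers ignore per-token scale storage and block-level padding.

```python
def kv_bytes_per_token(num_layers, num_kv_heads, head_dim, bytes_per_elem):
    """One K and one V vector (the factor 2) for every layer and KV head."""
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem


def max_cached_tokens(kv_budget_bytes, per_token_bytes):
    """Tokens of context that fit in a KV budget (ignores scales and block padding)."""
    return kv_budget_bytes // per_token_bytes
```

For the 70B-class configuration above, a BF16 cache costs 327,680 bytes per token and an fp8 cache 163,840, so a 40 GiB budget holds 131,072 tokens in BF16 and 262,144 in fp8.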

The accuracy story is also different, and worth being honest about. Weights are quantized once with a careful offline algorithm. KV-cache entries are produced during inference, token by token, so quantizing them means quantizing a fresh tensor on every step with whatever scale you have. The cache is touched twice per step, and it is worth being precise about both touches. On the write, the new token’s freshly computed K and V (in BF16) are divided by their scales and rounded to fp8 before being stored into the paged cache. On the read, the entire stored history is loaded back as fp8 and multiplied by the same scales to recover an approximate BF16 before attention multiplies it against the query. The scale is the bridge in both directions, which is why getting it right matters so much.

vLLM’s BaseKVCacheMethod attaches _k_scale/_v_scale to each attention layer (vllm/model_executor/layers/quantization/kv_cache.py), and those scales are used “to quantize k/v_cache entries before saving them to the cache” and “dequantize … before fetching them”. The write happens inside the paged-cache op from Chapter 4:

# vllm/v1/attention/backends/flash_attn.py
reshape_and_cache_flash(
    key, value,
    key_cache, value_cache,
    slot_mapping,
    self.kv_cache_dtype,
    layer._k_scale,
    layer._v_scale,
)

Source: vllm/v1/attention/backends/flash_attn.py

The cleanest way to get the scale is to calibrate it offline (run a calibration set, record the peak magnitude of K and V per layer, bake k_scale/v_scale into the checkpoint). process_weights_after_loading in kv_cache.py loads exactly those, and warns loudly if they are missing:

# vllm/model_executor/layers/quantization/kv_cache.py
if k_scale == 1.0 and v_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype:
    logger.warning_once(
        "Using KV cache scaling factor 1.0 for fp8_e4m3. "
        "If this is unintended, verify that k/v_scale "
        "scaling factors are properly set in the checkpoint.")

Source: vllm/model_executor/layers/quantization/kv_cache.py

A scale of 1.0 means “no real calibration”, which for fp8_e4m3 risks clipping: its narrow exponent range tops out at 448, so any K or V entry larger than that saturates, whereas fp8_e5m2 reaches 57,344, which is why the check above exempts e5m2. That warning is a tripwire for a common production mistake: turning on kv_cache_dtype=fp8 without a calibrated checkpoint and silently losing accuracy. Note also the newer *_per_token_head cache dtypes, where the scale is computed dynamically per token and head inside the kernel at write time, sidestepping the calibration problem at some kernel cost. There is no free lunch here, only a choice of where to pay.
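The difference between a calibrated static scale, the uncalibrated default of 1.0, and a dynamic per-token-head scale fits in a few lines. The calibration and live values are invented, the static rule (peak magnitude divided by 448) is a common choice rather than a transcription of vLLM's loader, and `per_token_head_scales` is a hypothetical illustration of what a dynamic kernel computes, not its actual code.

```python
FP8_E4M3_MAX = 448.0


def static_scale(calibration_batches):
    """Offline: one scale per layer from the peak |K| (or |V|) seen during calibration."""
    amax = max(abs(v) for batch in calibration_batches for v in batch)
    return amax / FP8_E4M3_MAX


def clipped_fraction(values, scale):
    """Fraction of entries that exceed the e4m3 range after dividing by the scale."""
    return sum(abs(v / scale) > FP8_E4M3_MAX for v in values) / len(values)


def per_token_head_scales(kv):
    """Hypothetical dynamic rule: kv[token][head] is a head_dim vector, one scale each.
    All-zero heads fall back to 1.0 to avoid dividing by zero."""
    return [[max(abs(x) for x in head) / FP8_E4M3_MAX or 1.0 for head in token]
            for token in kv]


calibration = [[0.5, -120.0, 33.0], [610.0, -12.0, 4.0]]   # invented calibration K values
live = [2.0, -300.0, 580.0, 700.0]                          # invented live K values

print(clipped_fraction(live, 1.0))                        # 0.5: 580 and 700 saturate
print(clipped_fraction(live, static_scale(calibration)))  # 0.25: only 700 exceeds the range

kv = [[[1.0, -2.0], [0.0, 0.0]], [[448.0, 10.0], [-896.0, 1.0]]]
print(per_token_head_scales(kv))
```

Even the calibrated static scale clips the 700 entry, because live traffic exceeded anything the calibration set contained. A per-(token, head) scale cannot be surprised that way, since it is recomputed from the very values being written, which is the kernel cost the text mentions.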

The sequence below traces the full life of one cached entry across two steps, showing where the scale enters on the write and again on the read. Notice that the stored bytes are always fp8 and the scale never lives in the cache itself; it lives on the layer, applied at the boundary each time the kernel crosses between fp8 storage and BF16 math.

sequenceDiagram
    participant Attn as "attention layer"
    participant Cache as "paged KV cache (fp8)"
    Note over Attn: step N (token just computed)
    Attn->>Attn: "compute K,V in BF16"
    Attn->>Cache: "reshape_and_cache_flash: divide by k_scale/v_scale, round to fp8, store"
    Note over Attn: step N+1 (next token)
    Cache->>Attn: "load history as fp8"
    Attn->>Attn: "multiply by k_scale/v_scale to recover BF16"
    Attn->>Attn: "attention(query, dequantized K,V)"

One important caveat ties back to Chapter 9: MLA models store a compressed latent KV, not raw K and V. Quantizing that latent (the fp8_ds_mla dtype) is a different operation from quantizing per-head K/V, which is why the dtype list has model-family-specific entries. The KV-cache lever exists for every architecture, but what exactly you are shrinking depends on the attention design.

Activations: the hard one

That leaves activations, the tensors flowing between layers, recomputed on every forward pass for every token. Quantizing them is the prize because it unlocks genuine low-precision GEMMs (FP8 tensor cores doing FP8-times-FP8 math, not just FP8-storage-then-dequant), which speeds up the compute-bound prefill and not just decode. It is also the hardest, for a reason the research names directly.

SmoothQuant (arXiv:2211.10438) is the paper to read here. Its core finding: activations contain systematic outlier channels whose magnitudes are far larger than the rest, and those outliers wreck naive per-tensor activation quantization, while the weights are comparatively smooth and easy. The reason outliers are so damaging is direct: a per-tensor scale has to be large enough to represent the biggest value in the tensor, so a single huge outlier channel stretches the scale and forces every ordinary value to round to a coarse grid, destroying their precision. Weights, by contrast, are evenly distributed and quantize cleanly.
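The outlier effect is easy to reproduce. The sketch below uses symmetric per-tensor int8 rather than FP8 because its uniform grid makes the damage easy to see, and the values are made up. Adding one large channel stretches the shared scale so far that the ordinary values collapse onto a coarse grid.

```python
def int8_per_tensor_roundtrip(values):
    """Symmetric per-tensor int8: one scale, set by the largest magnitude in the tensor."""
    scale = max(abs(v) for v in values) / 127
    return [max(-127, min(127, round(v / scale))) * scale for v in values]


def max_error(original, recovered):
    return max(abs(a - b) for a, b in zip(original, recovered))


ordinary = [0.12, -0.37, 0.58, -0.91, 0.44]   # invented well-behaved channels
with_outlier = ordinary + [60.0]              # plus one systematic outlier channel

clean_err = max_error(ordinary, int8_per_tensor_roundtrip(ordinary))
# Quantize the whole tensor (outlier included), then look only at the ordinary channels.
outlier_err = max_error(ordinary, int8_per_tensor_roundtrip(with_outlier)[: len(ordinary)])
print(f"worst error without outlier: {clean_err:.4f}, with outlier: {outlier_err:.4f}")
```

With the outlier present, the grid spacing becomes $60/127 \approx 0.47$, so 0.12 rounds all the way to zero: the ordinary channels lose essentially all of their precision.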

SmoothQuant’s fix is to migrate the difficulty from the hard side to the easy side. A linear layer computes a product of an activation $X$ and a weight $W$, that is $X \cdot W$. If you divide a troublesome activation channel by some factor $s$ and multiply the matching weight channel by the same $s$, the product is unchanged, since $(X / s) \cdot (s \cdot W) = X \cdot W$, but the activation is now tame and the weight has absorbed the bump. Since the weights were quantization-friendly to begin with, they tolerate the bump well, and now both operands quantize cleanly. This algebraic shuffle is done offline and baked into the checkpoint; it is the conceptual ancestor of AWQ’s salient-channel idea and the reason W8A8 checkpoints can exist at all.
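The per-channel migration can be checked numerically. The smoothing factor $s_j = \max|X_{:,j}|^{\alpha} / \max|W_{j,:}|^{1-\alpha}$ with migration strength $\alpha = 0.5$ follows the SmoothQuant paper. The matrices are invented, and a real pipeline folds $1/s$ into the preceding normalization layer offline instead of dividing $X$ at runtime. The sketch confirms that the product is unchanged while the outlier channel shrinks.

```python
def matmul(a, b):
    """Plain matrix product of nested lists: (n x m) @ (m x p)."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def smooth(x, w, alpha=0.5):
    """SmoothQuant-style migration: for each input channel j,
    s_j = max|X[:, j]|**alpha / max|W[j, :]|**(1 - alpha); X is divided by s, W multiplied."""
    n_in = len(w)
    s = []
    for j in range(n_in):
        x_max = max(abs(row[j]) for row in x)
        w_max = max(abs(v) for v in w[j])
        s.append(x_max ** alpha / w_max ** (1 - alpha))
    x_s = [[row[j] / s[j] for j in range(n_in)] for row in x]
    w_s = [[w[j][k] * s[j] for k in range(len(w[0]))] for j in range(n_in)]
    return x_s, w_s, s


x = [[0.5, 40.0, -0.3],    # 2 tokens x 3 channels; channel 1 is the outlier
     [-0.2, 55.0, 0.7]]
w = [[0.4, -0.1],          # 3 input channels x 2 outputs
     [0.2, 0.3],
     [-0.5, 0.6]]

x_s, w_s, s = smooth(x, w)
print("activation peak per channel before:", [max(abs(r[j]) for r in x) for j in range(3)])
print("activation peak per channel after: ",
      [round(max(abs(r[j]) for r in x_s), 3) for j in range(3)])
```

With $\alpha = 0.5$ each smoothed activation channel and its matching weight row end up with the same peak magnitude, $\sqrt{\max|X_{:,j}| \cdot \max|W_{j,:}|}$, so the difficulty is split evenly between the two operands.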

In vLLM, the static-vs-dynamic distinction from the config surface becomes concrete for activations. A static scheme has one activation scale baked in from calibration. A dynamic scheme measures the scale at runtime, per token, just before the GEMM, which is more accurate (it adapts to each input) but adds a quantize step on the critical path. The FP8 method keeps both options visible:

# vllm/model_executor/layers/quantization/fp8.py
ACTIVATION_SCHEMES = ["static", "dynamic"]

Source: vllm/model_executor/layers/quantization/fp8.py

The apply path is equally candid that a true FP8 GEMM is not always the route taken: the comment guarding the dispatch reads “we will use BF16 dequant when direct FP8 is not supported.” In the batch-invariant path, for instance, when no CUTLASS FP8 kernel applies, it simply dequantizes the weight back to BF16 and runs an ordinary GEMM:

# vllm/model_executor/layers/quantization/fp8.py
# per-tensor/channel: dequant to BF16 and run GEMM
weight_fp8 = layer.weight.to(torch.bfloat16)
weight_scale = layer.weight_scale.to(torch.bfloat16)

Source: vllm/model_executor/layers/quantization/fp8.py

The point generalizes beyond that one branch: low-precision activation math is hardware-gated, and where the right tensor-core path is unavailable the engine pays for a dequant rather than producing a wrong answer.
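Why is a dequant fallback still the right answer? With per-tensor scales, the scales factor out of the matrix product, so dequantizing first and applying the scales after a low-precision GEMM compute the same thing. The sketch below assumes per-tensor scales and exactly representable toy values so that rounding never enters. With real fp8 the two routes agree only up to rounding and accumulation error, and neither route mirrors vLLM's code.

```python
def matvec(w, x):
    """Plain high-precision matrix-vector product."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


w_codes = [[3.0, -2.0], [1.5, 4.0]]   # toy values standing in for stored fp8 weights
w_scale = 0.25                         # per-tensor weight scale
x = [1.0, -0.5]                        # BF16 activations
x_scale = 0.5                          # per-tensor activation scale
x_codes = [v / x_scale for v in x]     # exact here; rounded to fp8 in a real kernel

# Fallback route: dequantize the weight, then run an ordinary GEMM on x.
w_dequant = [[v * w_scale for v in row] for row in w_codes]
fallback = matvec(w_dequant, x)

# Native route: GEMM directly on the low-precision codes, then apply both scales.
native = [acc * w_scale * x_scale for acc in matvec(w_codes, x_codes)]

print(fallback, native)
```

The fallback costs a dequantize and forgoes the faster tensor-core math, but it changes neither the formula nor the result.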

The compressed-tensors scheme dispatcher reflects the same reality: W8A8 and W4A4 are real schemes, but they sit behind capability checks and specific kernels (CUTLASS FP8, NVFP4), and a checkpoint that asks for activation quantization the hardware cannot accelerate is routed to a slower but still correct path.

MoE, and where this is still unsettled

Mixture-of-experts models stress every byte argument harder, because their expert weights dominate the parameter count while only a couple of experts fire per token. Quantizing expert weights is therefore the single biggest lever for MoE memory, and it gets a dedicated path, MoeWNA16Config (“W8A16/W4A16 quantization”, vllm/model_executor/layers/quantization/moe_wna16.py) feeding weight-only quantized weights through the fused-MoE kernels. The fact that it is weight-only (A16) is telling: activation quantization inside the expert routing is even harder than in dense layers, so production MoE quant in 2026 is overwhelmingly weight-only.

What remains genuinely unsolved is worth stating plainly. There is no single accuracy metric that survives quantization cleanly: perplexity can look fine while a specific reasoning or code-generation capability degrades, and the degradation is model- and task-specific. Calibration data quality for activation and KV scales is a real, under-tooled operational burden. And the kernel zoo keeps growing (NVFP4, MXFP4, MXFP8), each new format racing hardware support, so the choose_mp_linear_kernel list is a moving target, not a settled design. Quantization buys throughput almost for free in the easy cases and demands careful, honest evaluation in the hard ones.

We have now spent fewer bytes per token. The next chapter spends fewer forward passes per token: speculative decoding verifies several cheaply-drafted tokens in one bandwidth-bound pass, trading the spare compute that quantization just made even more abundant for fewer trips through memory.

Further reading

  • GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers — arXiv:2210.17323 — one-shot second-order weight quantization to 3-4 bits; the accuracy foundation under vLLM’s GPTQ/Marlin path.
  • AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration — arXiv:2306.00978 — protect the small set of salient weight channels identified by activation magnitude; read it to understand why 4-bit weights work.
  • SmoothQuant: Accurate and Efficient Post-Training Quantization for Large Language Models — arXiv:2211.10438 — migrate activation outliers into the weights so both quantize cleanly; the key to why W8A8 activation quantization is feasible.

Speculative decoding: more tokens per forward pass

In Chapter 3 we drew the line that organizes this whole book: a decode step reads the entire model and the KV cache to produce a single token, so it is bandwidth-bound, with the GPU’s arithmetic units mostly idle. Chapter 5 then showed how continuous batching fills those idle FLOPs by stacking many sequences into one step. Speculative decoding attacks the same waste from a different angle, and for a different shape of traffic: instead of needing many concurrent requests to saturate the hardware, it lets a single request consume more of that spare compute, by checking several candidate tokens in the same forward pass that would otherwise produce just one.

The trick rests on an asymmetry. Producing the next token costs one full memory pass over the weights. But producing the logits for a prefix you already have in hand also costs one memory pass, whether that prefix is one token long or ten. To see why, recall what dominates the cost of a decode step. A transformer forward pass does two kinds of work: it reads the weights out of GPU memory once, and it multiplies those weights against the tokens flowing through. For the handful of tokens we care about here, the matmuls are tiny and the weight read dwarfs everything else, so the time the pass takes is set by how much memory it has to stream, not by how many tokens ride along inside it. Feeding two tokens or ten tokens through that single weight read costs almost the same wall-clock time. That is the lever speculative decoding pulls.
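A back-of-the-envelope roofline makes “almost the same wall-clock time” concrete. As a simplification, ignore KV-cache reads and attention FLOPs, and take a model with $P$ parameters stored in BF16 (so $2P$ bytes) that spends about $2P$ FLOPs per token (one multiply and one add per weight). A pass carrying $T$ tokens on a GPU with memory bandwidth $\mathrm{BW}$ and peak throughput $F$ then takes roughly

$$t(T) \approx \max\left(\frac{2P}{\mathrm{BW}},\ \frac{2PT}{F}\right),$$

and the memory term wins whenever

$$\frac{2PT}{F} < \frac{2P}{\mathrm{BW}} \iff T < \frac{F}{\mathrm{BW}}.$$

With illustrative round numbers of $F = 10^{15}$ FLOP/s and $\mathrm{BW} = 3 \times 10^{12}$ B/s, the threshold is about 333 tokens, so a pass carrying one token and a pass carrying ten both cost one weight read.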

So if you can cheaply guess the next handful of tokens, you can feed [last_token, guess_1, guess_2, ..., guess_k] through the big model in one shot and get, in parallel, the model’s true probability distribution at each of those positions. Call the big model the target (it is the source of truth, the model whose output you are obligated to reproduce) and call the cheap guesser the drafter (it proposes candidates the target then checks). Where the drafter’s guesses match what the target would have produced anyway, you keep them for free, because the single weight read already paid for verifying them. Where they diverge, you fall back to the target’s own choice and discard the rest. You have spent roughly one decode’s worth of bandwidth and harvested up to $k+1$ tokens instead of one, where $k$ is the number of drafted tokens.

The diagram below traces that loop end to end: one cheap draft step proposes $k$ candidate tokens, one target forward pass verifies all of them at once, and a verifier keeps the longest correct prefix before the loop repeats from wherever it stopped.

flowchart TD
    A["target has produced last_token"] --> B["drafter cheaply guesses<br/>guess_1 ... guess_k"]
    B --> C["build prefix<br/>[last_token, guess_1, ..., guess_k]"]
    C --> D["ONE target forward pass<br/>(single weight read)"]
    D --> E["target distribution p at every position"]
    E --> F["verifier: accept longest correct prefix"]
    F --> G["emit accepted drafts plus 1 target token<br/>(bonus if all accepted, else recovered)"]
    G --> A

The two hard questions are: who does the guessing, and how do you keep the guessed tokens without changing the model’s output distribution? vLLM answers the first with a family of proposers under vllm/v1/spec_decode/ (the “drafter” box above), and the second with a single rejection sampler in vllm/v1/sample/rejection_sampler.py (the “verifier” box). The rest of this chapter is those two halves and the scheduler glue that joins them.

The contract: drafting must be free, verification must be exact

The foundational result is Fast Inference from Transformers via Speculative Decoding (arXiv:2211.17192). Before the rule itself, fix two pieces of notation that recur throughout this section. Write $q(x)$ for the probability the drafter assigned to a token $x$ when it proposed it, and $p(x)$ for the probability the target assigns to that same token now that it has verified it. The drafter is trying to imitate the target, so $q$ is an approximation of $p$; the whole accept/reject machinery exists to correct the gap between them. The paper’s contribution is a rejection-sampling scheme that takes draft tokens from any cheap proposal distribution $q$ and accepts or rejects them against the target distribution $p$ such that the tokens you emit are distributed exactly as if you had sampled from $p$ directly. This is the load-bearing guarantee: speculation is a pure latency optimization, not an approximation. Read it first; everything else in this chapter is an engineering elaboration of its accept/reject rule.

vLLM names the rule directly. From the class docstring in vllm/v1/sample/rejection_sampler.py:

class RejectionSampler(nn.Module):
    """
    The implementation strictly follows the algorithm described in
        https://arxiv.org/abs/2211.17192.
    ...
    bonus tokens:
        If all proposed tokens are accepted, the bonus token is added to the
        end of the sequence. The bonus token is only sampled from the target
        probabilities.
    """

The “bonus token” is why you get $k+1$ and not just $k$: the forward pass over the drafted prefix also yields the logits for the position after the last draft, so if every draft is accepted you get one extra token sampled straight from the target for nothing.

The accept/reject decision lives in a Triton kernel. The greedy case is the easiest to see — for greedy sampling, “the target distribution” is just its argmax, so a draft is accepted exactly when it equals what the target would have picked. From rejection_greedy_sample_kernel in the same file:

            else:
                token_id = target_argmax_id
                rejected = draft_token_id != target_argmax_id
            tl.store(
                output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
                token_id,
            )

Source: vllm/v1/sample/rejection_sampler.py

Notice the kernel walks positions left to right and stops at the first rejection (if not rejected: guards each step). A draft is a chain: each guess was produced assuming every earlier guess was correct, so once the target diverges at position i, every later draft was conditioned on a token the target just rejected and is now meaningless. You cannot keep guess 5 if guess 3 was wrong, because guess 5 only made sense in a world where guess 3 stood. So the verifier accepts a contiguous prefix and discards the entire tail. This is why acceptance length, not raw draft length, is the metric that matters: drafting ten tokens does you no good if the chain reliably breaks at the second one.

The diagram below traces that left-to-right walk for a four-token draft where the target diverges at position 2. Positions 0 and 1 match and are kept; position 2 mismatches, so it is replaced by the target’s own token and the walk stops, throwing away position 3 even though it was already verified in the same forward pass.

flowchart LR
    P0["pos 0: draft == target?"] -->|"yes, accept draft"| P1["pos 1: draft == target?"]
    P1 -->|"yes, accept draft"| P2["pos 2: draft == target?"]
    P2 -->|"no, reject"| R["emit target's token here<br/>STOP the walk"]
    R --> X["pos 3: discarded<br/>(was conditioned on rejected pos 2)"]
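The greedy walk and the loop around it fit in a short pure-Python sketch. `target_next` and `drafter_next` are hypothetical toy models (a deterministic formula, and a drafter that deliberately errs at every third length), and the target's single forward pass is simulated by calling `target_next` once per position. The point it demonstrates is the contract: however often the drafter is wrong, the emitted tokens match plain greedy decoding of the target exactly.

```python
def verify_greedy(draft, target_argmax):
    """Greedy rule: target_argmax[i] is the target's choice after prefix position i of
    [last_token, draft_1, ..., draft_k], so it has len(draft) + 1 entries."""
    out = []
    for pos, token in enumerate(draft):
        out.append(target_argmax[pos])     # the accepted draft, or the target's correction
        if token != target_argmax[pos]:    # first mismatch: stop and drop the tail
            return out
    out.append(target_argmax[len(draft)])  # every draft accepted: bonus token
    return out


def target_next(ctx):
    """Hypothetical deterministic 'target model' over a 10-token vocabulary."""
    return (sum(ctx) * 7 + 3) % 10


def drafter_next(ctx):
    """Hypothetical cheap drafter: agrees with the target except at every third length."""
    guess = target_next(ctx)
    return guess if len(ctx) % 3 else (guess + 1) % 10


def speculative_generate(prompt, n_tokens, k):
    seq = list(prompt)
    while len(seq) < len(prompt) + n_tokens:
        draft = []
        for _ in range(k):                               # cheap autoregressive drafting
            draft.append(drafter_next(seq + draft))
        # Stand-in for ONE target forward pass over [last_token, draft_1..draft_k].
        target_argmax = [target_next(seq + draft[:i]) for i in range(k + 1)]
        seq.extend(verify_greedy(draft, target_argmax))
    return seq[len(prompt):len(prompt) + n_tokens]


def greedy_generate(prompt, n_tokens):
    seq = list(prompt)
    for _ in range(n_tokens):
        seq.append(target_next(seq))
    return seq[len(prompt):]


print(verify_greedy([5, 6, 7, 8], [5, 6, 9, 1, 2]))   # [5, 6, 9]: diverges at position 2
print(speculative_generate([1, 2, 3], 12, k=4) == greedy_generate([1, 2, 3], 12))  # True
```

The first call replays the four-token example from the diagram: positions 0 and 1 are kept, position 2 is replaced by the target's token, and position 3 is discarded.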

The random (temperature > 0) case is where 2211.17192 earns its keep. From rejection_random_sample_kernel:

                target_prob = tl.load(
                    target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id
                )
                # NOTE(woosuk): While the draft probability should never be 0,
                # we check it to avoid NaNs. If it happens to be 0, we reject.
                accepted = draft_prob > 0 and target_prob / draft_prob >= uniform_prob

Source: vllm/v1/sample/rejection_sampler.py

Accept the draft token $x$ with probability $\min\left(1, p(x)/q(x)\right)$. The intuition: if the target wanted $x$ at least as much as the drafter did ($p(x) \geq q(x)$), the draft is always kept, because the drafter did not over-propose it. But if the drafter over-proposes $x$ relative to the target’s true taste ($p(x) < q(x)$), you keep it only with probability $p(x)/q(x)$, scaling down exactly the excess. That is the accept step. On its own it emits each token with probability $q(x) \cdot \min\left(1, p(x)/q(x)\right) = \min\left(p(x), q(x)\right)$: exactly right for tokens the drafter over-proposed, but too low for every token the target wanted more than the drafter proposed it, and zero for tokens the drafter never guessed at all, so the output distribution would be biased.

The rejection step repairs that bias. On rejection you do not simply resample from $p$ — that would double-count the probability mass the accept step already spent. Instead you sample from the residual distribution $\max(p - q, 0)$, normalized. Read that residual as “the part of the target’s distribution the drafter under-served”: wherever the target wanted a token more than the drafter proposed it, the leftover mass $p - q$ lives there, and drawing from it precisely fills the gap the accept step left. Add the two steps together and the token you finally emit is distributed exactly as $p$. vLLM precomputes those “recovered” tokens (the residual draws) in sample_recovered_tokens before the kernel runs, using a Gumbel-max trick over max(target_prob - draft_prob, 0.0), so the kernel can grab one with a single load when it needs it. The math is subtle; the payoff is that the output stream is provably indistinguishable from ordinary sampling. A reviewer who is nervous about correctness can stop worrying here: this is the whole point of the design.

One detail worth flagging for n-gram drafting, which has no probabilities of its own: the kernel falls back to draft_prob = 1 when NO_DRAFT_PROBS is set, which reduces the random rule to “accept iff target_prob >= uniform_prob.” It still samples correctly; it just cannot claim the draft agreed with the target, so each copied token is accepted with probability exactly $p(x)$, the target’s own probability for it.
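The accept step and the residual draw can be sketched for a single position in pure Python. This illustrates the 2211.17192 rule and is not vLLM's Triton code: it draws the recovered token with `random.choices` instead of a Gumbel-max-style trick, and `emitted_distribution` computes the output probabilities in closed form so you can check that they equal $p$. A one-hot $q$ models the n-gram case, where the kernel treats draft_prob as 1. The distributions are invented.

```python
import random


def residual(p, q):
    """Normalized max(p - q, 0): the target mass the drafter under-served."""
    r = [max(pi - qi, 0.0) for pi, qi in zip(p, q)]
    total = sum(r)
    if total == 0.0:            # p == q: rejection never happens, so any choice is harmless
        return list(p)
    return [ri / total for ri in r]


def speculative_sample(p, q, rng):
    """One position: draft x ~ q, keep it with prob min(1, p[x]/q[x]), else draw the residual."""
    x = rng.choices(range(len(q)), weights=q)[0]
    if rng.random() < min(1.0, p[x] / q[x]):
        return x
    return rng.choices(range(len(p)), weights=residual(p, q))[0]


def emitted_distribution(p, q):
    """Closed form: accept mass q * min(1, p/q) = min(p, q), plus reject mass times residual."""
    accept = [min(pi, qi) for pi, qi in zip(p, q)]
    p_reject = 1.0 - sum(accept)
    return [a + p_reject * r for a, r in zip(accept, residual(p, q))]


p = [0.5, 0.3, 0.2, 0.0]         # invented target distribution
q = [0.2, 0.2, 0.4, 0.2]         # invented drafter distribution
one_hot = [0.0, 1.0, 0.0, 0.0]   # n-gram style draft: token 1 "with probability 1"

print(emitted_distribution(p, q))
print(emitted_distribution(p, one_hot))

rng = random.Random(0)
counts = [0] * len(p)
for _ in range(100_000):
    counts[speculative_sample(p, q, rng)] += 1
print([c / 100_000 for c in counts])
```

The closed form also exposes the identity behind the residual's normalization: $\sum_x \max(p(x) - q(x), 0) = 1 - \sum_x \min(p(x), q(x))$, which is exactly the probability of rejecting, so the two steps together place mass $p(x)$ on every token.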

Who drafts: a zoo of proposers behind one interface

The proposers in vllm/v1/spec_decode/ differ entirely in how they produce candidates and not at all in how those candidates are verified. They sit on a rough spectrum from “no model at all” to “a second neural network.” The choice is a trade between how cheap the drafter is to run and how often its guesses are accepted: a free CPU-side lookup wins big when it hits but contributes nothing when it misses, while a learned drafter costs real GPU time per step but lands far more of its drafts. The right metric is not raw draft length but the expected number of accepted tokens per step: if each draft position $i$ is accepted with probability $\alpha_i$, a step that drafts $k$ tokens emits on average

$$\mathbb{E}[\text{tokens per step}] = 1 + \sum_{i=1}^{k} \prod_{j=1}^{i} \alpha_j$$

tokens (the leading $1$ is the guaranteed target token, and each draft only counts if every earlier draft in its chain was also accepted). That formula is the whole economics of speculation in one line, and it is sharply nonlinear: the curve below evaluates it for a constant per-position acceptance rate $\alpha$ and shows expected tokens per step climbing as $\alpha$ rises. Two things jump out. First, the payoff is convex in $\alpha$, so the last stretch of acceptance is worth far more than the first. Second, a longer draft $k$ only helps when acceptance is already high, because every extra position is gated behind the product of all the ones before it. Drafting eight tokens at $\alpha = 0.5$ yields about 2.0 tokens per step, barely more than the 1.75 from drafting two; drafting eight at $\alpha = 0.9$ yields about 6.1.

Computed exactly from the formula above under the simplifying assumption that every draft position shares one acceptance rate $\alpha$; real drafters have a different $\alpha_i$ at each position, which the next chart shows.
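The formula is cheap to evaluate directly. The sketch takes a list of per-position acceptance rates, so it covers both the constant-$\alpha$ case and drafters whose $\alpha_i$ decays with depth; the decaying list in the last line is made up.

```python
from math import prod


def expected_tokens_per_step(alphas):
    """1 + sum_{i=1..k} prod_{j=1..i} alpha_j, where alphas[j - 1] holds alpha_j."""
    return 1 + sum(prod(alphas[:i]) for i in range(1, len(alphas) + 1))


for alpha in (0.5, 0.9):
    for k in (2, 8):
        print(f"alpha={alpha}, k={k}: {expected_tokens_per_step([alpha] * k):.3f}")

# A made-up drafter whose acceptance decays with depth:
print(f"{expected_tokens_per_step([0.9, 0.8, 0.6, 0.3]):.3f}")
```

The four constant-$\alpha$ lines reproduce the figures in the text: 1.750 and 1.996 at $\alpha = 0.5$, 2.710 and 6.126 at $\alpha = 0.9$.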

The diagram below lays out the proposer spectrum from cheapest to heaviest; the prose then walks it left to right.

flowchart LR
    L["cheaper to run,<br/>no GPU cost"] --> A["n-gram / prompt lookup<br/>(no model, copies past text)"]
    A --> B["suffix decoding<br/>(suffix tree, dynamic depth)"]
    B --> C["Medusa<br/>(parallel heads on target)"]
    C --> D["EAGLE / MTP<br/>(autoregressive feature drafter)"]
    D --> E["draft model<br/>(small standalone network)"]
    E --> R["heavier drafter,<br/>real GPU latency per step"]

The cheapest is n-gram / prompt lookup (ngram_proposer.py). An n-gram here just means a short run of $n$ consecutive tokens. The proposer does not run a model at all; it takes the last few tokens the request has produced (the current suffix), searches earlier in the same context for a place where that exact run appeared before, and proposes whatever tokens followed it last time as the draft. If the model said “the quick brown fox” once and has just produced “the quick brown” again, the obvious guess is “fox.” This is shockingly effective for tasks with verbatim copying — code editing, RAG with quoted context, summarization that echoes the source. The core is a Knuth-Morris-Pratt-style longest-suffix-match (KMP is the classic linear-time string-search algorithm), JIT-compiled with numba and run across the batch in parallel. The proposer’s docstring states the rule plainly:

    """
    Find the longest n-gram which matches the suffix of the given tokens
    whose length is within [min_ngram, max_ngram] (inclusive).

    If found, we will extract k right after the matched ngram.
    """

Because it costs nothing but CPU cycles, n-gram drafting is essentially free latency upside when it hits, and a no-op when it misses. The same file is careful never to draft past the model length (k = min(k, max_model_len - total_token)), which is the kind of boundary bookkeeping every proposer has to get right.
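A minimal sketch of the lookup rule quoted above follows. It is an assumption-laden stand-in for ngram_proposer.py: `propose_ngram` is a hypothetical name, it uses a naive scan rather than the KMP-style matcher and numba, and it prefers the most recent earlier occurrence when several match (a choice made for this sketch, not necessarily vLLM's). It does keep the same max_model_len clamp.

```python
def propose_ngram(tokens, min_ngram, max_ngram, k, max_model_len):
    """Hypothetical prompt-lookup sketch (assumes min_ngram >= 1): find the longest suffix
    n-gram with min_ngram <= n <= max_ngram that also occurs earlier, and copy up to k
    tokens that followed it. Naive scan, not a KMP-style matcher."""
    k = min(k, max_model_len - len(tokens))        # never draft past the model length
    if k <= 0:
        return []
    for n in range(min(max_ngram, len(tokens) - 1), min_ngram - 1, -1):
        suffix = tokens[-n:]
        # Earlier occurrences only; scan from the most recent (this sketch's tie-break).
        for start in range(len(tokens) - n - 1, -1, -1):
            if tokens[start:start + n] == suffix:
                return tokens[start + n:start + n + k]
    return []


words = "the quick brown fox jumps over the lazy dog and the quick brown".split()
print(propose_ngram(words, 1, 3, k=2, max_model_len=100))   # ['fox', 'jumps']
print(propose_ngram(words, 1, 3, k=2, max_model_len=14))    # ['fox']: clamped to 1 slot
```

Words stand in for token ids to keep the example readable; the function works the same on lists of integers.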

A step up is suffix decoding (suffix_decoding.py), which generalizes n-gram lookup to suffix trees built over the current prompt and over outputs from earlier requests (it wraps Arctic Inference’s implementation and is described in arXiv:2411.04975). It can speculate a dynamic number of tokens per step depending on how confident the tree is, rather than a fixed $k$.

Then come the learned drafters. Medusa (medusa.py) is the simplest of these. Its contribution, Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads (arXiv:2401.10774), is to bolt several small feed-forward “heads” onto the target model itself, each trained to predict the token at offset +1, +2, +3, and so on from the current hidden state. There is no separate model to run autoregressively — one forward pass of the target produces a hidden state, and the heads fan it out into several draft tokens at once:

        # Generate blocks and compute logits
        blocks = self.model(target_hidden_states)
        logits = self.model.compute_logits(blocks)

        # Compute argmax for each Medusa head and stack into a single tensor
        # Shape: [batch_size, num_heads]
        draft_tokens = torch.stack([logit.argmax(dim=-1) for logit in logits], dim=1)

The catch, which Medusa accepts, is that the +2 head is predicting from the same hidden state as the +1 head, with no knowledge of what +1 actually emitted. The drafts are not properly conditioned on each other, so acceptance falls off fast with depth.

EAGLE fixes exactly that. EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty (arXiv:2401.15077) argues that the right thing to draft autoregressively is not tokens but the target’s feature (hidden state) sequence, which is smoother and more predictable, and then to read tokens off those features. The EAGLE head is a small transformer that does run autoregressively, consuming the target’s last hidden state to draft a properly-conditioned chain. In vLLM this is almost pure configuration on top of a shared base — the entire proposer is:

class EagleProposer(SpecDecodeBaseProposer):
    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        super().__init__(
            vllm_config,
            device,
            pass_hidden_states_to_model=True,
            runner=runner,
        )

Source: vllm/v1/spec_decode/eagle.py

That pass_hidden_states_to_model=True flag is the whole distinction from a generic draft model (draft_model.py, the classic 2211.17192 setup of a small standalone model of the same family), which sets it to False and works from tokens alone. EAGLE and its MTP variants are the production default for most 2026 open models that ship a trained draft head; the long list of *_mtp method names in vllm/config/speculative.py is the roster of models that bake the draft head into their release.

All of this is selected through one config surface. SpeculativeConfig in vllm/config/speculative.py carries num_speculative_tokens (the $k$ above), the method, and the optional draft model, and can auto-detect the method from the model when you do not name it.
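For concreteness, a configuration for the cheapest proposer might look like the dictionary below. The field names are taken from the n-gram example in vLLM's speculative decoding documentation and should be checked against the SpeculativeConfig fields in your installed version; the values are arbitrary.

```python
# Field names follow vLLM's n-gram speculative decoding example; verify them against
# vllm/config/speculative.py for your version. The values here are arbitrary.
speculative_config = {
    "method": "ngram",             # which proposer drafts
    "num_speculative_tokens": 5,   # k: drafts verified per target forward pass
    "prompt_lookup_max": 4,        # longest n-gram the lookup tries to match
    "prompt_lookup_min": 2,        # shortest n-gram it will accept
}

# Passed to the engine as, for example:
#   from vllm import LLM
#   llm = LLM(model="<your model>", speculative_config=speculative_config)
```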

Fitting speculation into the token-budget scheduler

Here is where Chapter 5 pays off. We promised that the token-budget scheduler treats prefill, decode, chunking, and speculation as one problem, and the design note in vllm/v1/core/sched/scheduler.py says so outright:

        # NOTE(woosuk) on the scheduling algorithm:
        # There's no "decoding phase" nor "prefill phase" in the scheduler.
        # Each request just has the num_computed_tokens and
        # num_tokens_with_spec. num_tokens_with_spec =
        # len(prompt_token_ids) + len(output_token_ids) + len(spec_token_ids).
        # At each step, the scheduler tries to assign tokens to the requests
        # so that each request's num_computed_tokens can catch up its
        # num_tokens_with_spec.

Unpack the two counters in that note. num_computed_tokens is how many tokens of this request the model has actually run through a forward pass; num_tokens_with_spec is how many tokens the request wants run, including the speculative tail the proposer just appended as spec_token_ids. A speculating request is simply one whose num_tokens_with_spec is a few tokens ahead of where it would otherwise be. The scheduler does not have a special “verify a draft” code path; it allocates num_new_tokens = num_tokens_with_spec - num_computed_tokens slots, exactly as it would for a chunk of prefill, and the same forward pass that would have decoded one token now happens to cover last_token plus the $k$ drafts. This is the elegance of the V1 design: speculation rides on the same token-level rate limiter you already understand, and the same KV-block allocation, with no separate machinery.

The interesting half is the cleanup after the model runs, because the scheduler advanced num_computed_tokens optimistically. Before the forward pass it had no way to know how many drafts would be accepted, so it assumed the best case and counted all $k$ drafts as computed. The forward pass then verified $k$ drafts, the rejection sampler accepted some prefix of them, and now num_computed_tokens is too high by exactly the number of rejected drafts — those positions ran through the model but their tokens never became part of the real sequence. The fix is to roll the counter back. From update_from_output:

            if scheduled_spec_token_ids and generated_token_ids:
                num_draft_tokens = len(scheduled_spec_token_ids)
                num_accepted = len(generated_token_ids) - 1
                num_rejected = num_draft_tokens - num_accepted
                ...
                if request.num_computed_tokens > 0:
                    request.num_computed_tokens -= num_rejected

Source: vllm/v1/core/sched/scheduler.py

That num_accepted = len(generated_token_ids) - 1 is the bonus token at work: the sampler always returns exactly one real token at the end (the bonus, or the recovered token at the rejection point), so every other token it returns is an accepted draft. The rejected tail is rolled back by decrementing num_computed_tokens, and the KV slots those rejected tokens wrote are simply overwritten on the next step. No preemption, no recompute: when acceptance is zero, speculation degrades to ordinary decode, costing only the wasted drafting and verification compute.
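The optimistic advance and the rollback are easiest to trust once you watch the counters move. `ToyRequest`, `schedule`, and `update_from_output` below are hypothetical simplifications of the scheduler's bookkeeping (no token budget, no KV blocks, a single request), but the arithmetic follows the excerpts above: schedule the gap between num_tokens_with_spec and num_computed_tokens, count it as computed, then subtract the rejected drafts.

```python
from dataclasses import dataclass, field


@dataclass
class ToyRequest:
    """Hypothetical stand-in for a scheduler request; only the counters matter here."""
    prompt_token_ids: list
    output_token_ids: list = field(default_factory=list)
    spec_token_ids: list = field(default_factory=list)
    num_computed_tokens: int = 0

    @property
    def num_tokens(self) -> int:
        return len(self.prompt_token_ids) + len(self.output_token_ids)

    @property
    def num_tokens_with_spec(self) -> int:
        return self.num_tokens + len(self.spec_token_ids)


def schedule(req):
    """Allocate the gap and optimistically count all of it as computed."""
    num_new_tokens = req.num_tokens_with_spec - req.num_computed_tokens
    req.num_computed_tokens += num_new_tokens
    scheduled_spec = req.spec_token_ids
    req.spec_token_ids = []          # drafts are consumed by this step
    return num_new_tokens, scheduled_spec


def update_from_output(req, scheduled_spec, generated):
    """Roll back rejected drafts, then append what the sampler actually emitted."""
    if scheduled_spec and generated:
        num_accepted = len(generated) - 1          # all but the bonus/recovered token
        num_rejected = len(scheduled_spec) - num_accepted
        if req.num_computed_tokens > 0:
            req.num_computed_tokens -= num_rejected
    req.output_token_ids.extend(generated)


req = ToyRequest(prompt_token_ids=[11, 12, 13, 14, 15])
n, spec = schedule(req)                          # prefill: 5 prompt tokens
update_from_output(req, spec, [100])             # first sampled token
print(req.num_tokens, req.num_computed_tokens)   # 6 5

req.spec_token_ids = [1, 2, 3, 4]                # proposer appends k = 4 drafts
n, spec = schedule(req)
print(n, req.num_computed_tokens)                # 5 10: last_token + 4 drafts, optimistic
update_from_output(req, spec, [1, 2, 99])        # 2 drafts accepted, 99 recovered
print(req.num_tokens, req.num_computed_tokens)   # 9 8
```

After every step the request ends with num_computed_tokens equal to num_tokens - 1: the newest token has been sampled but not yet run through the model, and it is exactly the last_token the next step feeds in.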

The sequence diagram below traces one full speculative step across the three actors, including the optimistic advance and the corrective rollback. Read it top to bottom as a single iteration of the loop from the first diagram in this chapter.

sequenceDiagram
    participant S as Scheduler
    participant P as Proposer drafter
    participant M as Model and RejectionSampler
    P->>S: append spec_token_ids (k drafts)
    S->>S: advance num_computed_tokens by k+1 (optimistic)
    S->>M: allocate slots for last_token + k drafts
    M->>M: one forward pass over the whole prefix
    M->>M: walk drafts, accept prefix, add bonus or recovered token
    M->>S: generated_token_ids (accepted + 1 bonus or recovered)
    S->>S: num_rejected = k - (len - 1)
    S->>S: roll back num_computed_tokens by num_rejected
    Note over S,M: rejected KV slots overwritten next step

The engine also keeps score. SpecDecodingStats in vllm/v1/spec_decode/metrics.py tracks acceptance both in aggregate and per position:

    def observe_draft(self, num_draft_tokens: int, num_accepted_tokens: int):
        self.num_drafts += 1
        self.num_draft_tokens += num_draft_tokens
        self.num_accepted_tokens += num_accepted_tokens
        assert num_accepted_tokens <= self.num_spec_tokens
        for i in range(num_accepted_tokens):
            self.num_accepted_tokens_per_pos[i] += 1

The per-position array is the diagnostic you will actually use in production. It tells you the acceptance decay curve: how often position 0 is accepted versus position 3. A curve that collapses after the first token means your $k$ is set too high for this drafter on this traffic, and you are burning verification compute on drafts that never land. The chart below sketches that decay for two kinds of drafter on the same traffic: a strong feature drafter (EAGLE-style) whose per-position acceptance falls slowly, and a weak drafter (a shallow n-gram match, say) that lands the first guess often but collapses by position two or three. When reading the raw counters, note that observe_draft increments every position below num_accepted_tokens, so num_accepted_tokens_per_pos[i] / num_drafts estimates the chained product $\prod_{j=1}^{i+1} \alpha_j$ from the formula above rather than $\alpha_{i+1}$ itself; the per-position rate is the ratio of adjacent entries. The area under that chained-product curve is the expected number of accepted drafts per step, so positions where it has flattened near zero are pure wasted verification.

Illustrative shapes, not measured: the exact numbers are model- and traffic-dependent, and the num_accepted_tokens_per_pos array is precisely what you would log to recover the real curve.
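Recovering both curves from the logged counters takes a few lines. Because observe_draft increments every position up to num_accepted_tokens, entry i divided by num_drafts is the reach of position i (the chained product), and the ratio of adjacent entries is the conditional $\alpha$. The counts below are invented, and `suggest_k` with its 20% threshold is a hypothetical tuning heuristic, not something vLLM does for you.

```python
def acceptance_curves(num_drafts, num_accepted_tokens_per_pos):
    """reach[i]: fraction of drafts whose position i (0-based) was accepted,
    i.e. alpha_1 * ... * alpha_{i+1}.
    alpha[i]: acceptance at position i given that position i - 1 was accepted."""
    reach = [c / num_drafts for c in num_accepted_tokens_per_pos]
    alpha = []
    prev = num_drafts
    for c in num_accepted_tokens_per_pos:
        alpha.append(c / prev if prev else 0.0)
        prev = c
    return reach, alpha


def suggest_k(reach, min_reach=0.2):
    """Hypothetical heuristic: stop drafting at the first position that lands too rarely."""
    k = 0
    for r in reach:
        if r < min_reach:
            break
        k += 1
    return k


# Invented counters, shaped like SpecDecodingStats after 1,000 drafts of k = 4.
reach, alpha = acceptance_curves(1000, [800, 560, 280, 70])
print("reach:", reach)                      # chained products
print("alpha:", alpha)                      # conditional per-position rates
print("expected tokens/step:", 1 + sum(reach))
print("suggested k:", suggest_k(reach))
```

Note that $1 + \sum_i \text{reach}_i$ is the empirical version of the expected-tokens-per-step formula, so the same counters also tell you what speculation is currently buying.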

Where it helps, where it hurts, what is unsolved

Speculative decoding is not free throughput; it is a trade. You spend extra FLOPs (the draft, plus verifying tokens you may discard) to buy fewer memory passes per accepted token. That trade is favorable precisely when the GPU has spare FLOPs — i.e. at low-to-moderate batch sizes, where decode is most bandwidth-starved. At high concurrency, where continuous batching has already filled the compute, speculation competes with real requests for the same FLOPs and can reduce aggregate throughput even as it lowers single-request latency. This is the central tension to internalize: speculation and batching both target idle compute, so they are partly substitutes, and the right num_speculative_tokens is load-dependent. A server tuned for a packed batch should speculate little or not at all; one serving a latency-sensitive trickle should speculate aggressively.
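A rough roofline model makes the load dependence concrete. Every number below is an illustrative assumption (a 70B-parameter BF16 model, round-figure bandwidth and peak FLOPs, a constant acceptance rate of 0.7), and the model ignores drafter cost, KV-cache reads, and attention FLOPs. It sketches the trade and is not a capacity planner.

```python
WEIGHT_BYTES = 140e9   # hypothetical 70B-parameter model stored in BF16
PARAMS = 70e9
HBM_BW = 3.0e12        # bytes/s, illustrative round number
PEAK_FLOPS = 1.0e15    # FLOP/s, illustrative round number


def step_time(tokens_in_pass):
    """Roofline: a step costs the weight read or the matmul FLOPs, whichever is larger."""
    memory_s = WEIGHT_BYTES / HBM_BW
    compute_s = 2 * PARAMS * tokens_in_pass / PEAK_FLOPS
    return max(memory_s, compute_s)


def throughput(batch, k, alpha):
    """Emitted tokens/s for `batch` sequences that each draft k tokens per step.
    Ignores drafter cost, KV-cache reads, and attention FLOPs."""
    expected = 1 + sum(alpha ** i for i in range(1, k + 1))
    return batch * expected / step_time(batch * (k + 1))


for batch in (8, 256):
    ratio = throughput(batch, 4, 0.7) / throughput(batch, 0, 0.7)
    print(f"batch={batch}: speculating k=4 changes throughput by {ratio:.2f}x")
```

At batch 8 every step is still bound by the weight read, so speculation multiplies throughput by the expected tokens per step. At batch 256 the extra draft positions push the pass into the compute-bound regime and aggregate throughput drops, even though each request still gets more tokens per step.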

The open problems are real. Dynamic speculation length — letting $k$ track the live acceptance curve and the current batch size rather than sitting at a fixed config value — is only partly solved (suffix decoding’s variable depth is one attempt). Draft quality is the other frontier: EAGLE-style feature drafting raised acceptance dramatically over plain draft models, but every gain in the drafter costs drafter latency, and the optimum is model- and task-specific. And there is an interaction we have deliberately deferred: when output is constrained to a grammar, the bitmask that bans illegal tokens must be applied to draft positions too, and rolled back over rejected drafts so the FSM state stays consistent. That handshake is the explicit reason Chapter 14 builds on this one — structured generation has to teach the constraint engine to speak speculation’s accept/reject language. We pick it up there.

Further reading

  • Fast Inference from Transformers via Speculative Decoding — arXiv:2211.17192 — the rejection-sampling rule vLLM implements verbatim; read it to understand why accepted tokens are distributed exactly as the target.
  • EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty — arXiv:2401.15077 — draft at the feature (hidden-state) level autoregressively rather than at the token level; the production-default drafter for trained heads.
  • Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads — arXiv:2401.10774 — bolt parallel prediction heads onto the target model itself; simplest learned drafter, with acceptance that decays fast with depth.
  • SuffixDecoding: Extreme Speculative Decoding for Emerging AI Applications — arXiv:2411.04975 — a model-free approach that generalizes n-gram lookup to per-prompt suffix trees with dynamic speculation depth.

Structured and guided generation

A surprising fraction of production LLM traffic does not want free-form prose. It wants JSON that parses, a tool call whose arguments match a schema, an enum that is exactly one of three strings, a date in YYYY-MM-DD. The moment a model output feeds another program rather than a human reader, “mostly valid” is a bug. Retrying until the model happens to emit well-formed output is the naive fix, and it is expensive: you pay full decode latency for a response you then throw away.

The better fix turns out to be small and surgical, and it slots neatly into machinery we already have. Recall from Chapter 11 that the sampler is a pipeline operating on the last-position logits, the vector of one raw score per vocabulary token that the model produces before any token is chosen, and that the single mandatory GPU-to-CPU sync is the sampled token id. If we could reach into that pipeline one step earlier and forbid the tokens that would violate our grammar, the model would be physically unable to produce malformed output. That is the whole idea of structured generation: at every decode step, compute which tokens are legal given what has been emitted so far, and set the logits of all the illegal ones to negative infinity before sampling. A token whose logit is negative infinity has zero probability after the softmax, so the sampler can never pick it, no matter the temperature or sampling method. The legality test comes from a grammar, a formal description of every valid string (a JSON-schema, an EBNF rule set, a fixed list of choices), which the engine compiles into a finite-state machine (FSM): a little automaton whose current state encodes “given what has been emitted so far, here is exactly the set of tokens that may come next.”
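
A minimal pure-Python check of the negative-infinity argument: once a logit is -inf, its softmax weight is exactly zero at every temperature, so nothing that samples from that distribution can select it. The helper names are illustrative, not vLLM's; the real sampler does this on GPU tensors.

```python
import math
from typing import List, Sequence


def mask_logits(logits: Sequence[float], allowed: Sequence[bool]) -> List[float]:
    # Illegal tokens get -inf; legal ones keep their raw score.
    return [x if ok else -math.inf for x, ok in zip(logits, allowed)]


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    # Assumes at least one legal token. Subtract the max finite logit for
    # numerical stability; math.exp(-inf) is exactly 0.0.
    m = max(x for x in logits if x != -math.inf)
    exps = [math.exp((x - m) / temperature) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 5.0, 1.0, 4.0]
allowed = [True, False, True, False]  # e.g. only tokens 0 and 2 fit the grammar
for t in (0.5, 1.0, 5.0):
    probs = softmax(mask_logits(logits, allowed), temperature=t)
    print(t, [round(p, 3) for p in probs])  # tokens 1 and 3 always get 0.0
```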

The catch is where that “compute which tokens are legal” work happens. Walking that FSM over a 150,000-entry vocabulary to decide one bit per token is pure CPU bookkeeping, and the GPU is the scarce resource we spent the entire first half of this book learning not to stall. So the engineering problem is a scheduling problem dressed as a parsing problem: keep the FSM work off the GPU’s critical path, hand the GPU a finished bitmask, and have it do nothing more than a cheap masked write. And because Chapter 13 taught the engine to speculate several tokens ahead in one forward pass, the mask machinery has to cope with tokens that get drafted, masked, and then rejected, rolling the grammar’s state back as if they never happened.

The diagram below shows the division of labor that the rest of this chapter unpacks. The CPU walks the grammar and produces a bitmask; the GPU does its normal forward pass in parallel and then applies that mask in a single cheap kernel right before sampling. The two lanes overlap, so the masking adds almost nothing to per-step latency.

flowchart LR
    subgraph CPU["CPU lane (off the GPU critical path)"]
        G["grammar FSM: current state"] --> F["fill_next_token_bitmask: 1 bit per vocab token"]
        F --> M["packed bitmask (int32 per 32 tokens)"]
    end
    subgraph GPU["GPU lane"]
        FW["forward pass: produce logits"] --> AP["apply_token_bitmask_inplace: illegal logits set to -inf"]
        AP --> S["sampler: pick a legal token"]
    end
    M -->|"non_blocking host-to-device copy"| AP
    S -->|"sampled token id"| G

A mask is a bitmask

Constraining the vocabulary is conceptually a boolean vector of length vocab_size: one bit per token, set if the token is currently allowed. Storing 150,000 separate booleans would waste memory and bandwidth, so vLLM packs them, 32 tokens per int32, the way a compiler packs flags into bit fields. Allocating a fresh vector for every request on every step would also churn the allocator, so the backends allocate the whole batch’s worth up front, a two-dimensional buffer of shape (sequences, vocab-words), where vocab-words is vocab_size / 32 rounded up. The xgrammar backend’s allocator is a one-liner:

# vllm/v1/structured_output/backend_xgrammar.py
def allocate_token_bitmask(self, max_num_seqs: int):
    return xgr.allocate_token_bitmask(max_num_seqs, self.vocab_size)

Source: vllm/v1/structured_output/backend_xgrammar.py
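
The packing arithmetic is easy to sketch. This assumes an LSB-first layout, where token t is allowed when bit t % 32 of word t // 32 is set, taken here to match XGrammar's convention; treat the bit order as an assumption and check it against your backend. Words are held as signed int32 values, which is why a fully allowed word reads as -1. The function names are hypothetical.

```python
from typing import List, Sequence


def words_per_row(vocab_size: int) -> int:
    return (vocab_size + 31) // 32  # ceil(vocab_size / 32)


def pack_allowed(allowed: Sequence[bool]) -> List[int]:
    """Pack one boolean per token into signed int32 words, LSB first."""
    words = [0] * words_per_row(len(allowed))
    for token_id, ok in enumerate(allowed):
        if ok:
            words[token_id // 32] |= 1 << (token_id % 32)
    # Reinterpret each unsigned 32-bit word as a signed int32.
    return [w - (1 << 32) if w >= (1 << 31) else w for w in words]


def is_allowed(words: Sequence[int], token_id: int) -> bool:
    word = words[token_id // 32] & 0xFFFFFFFF  # back to the raw unsigned bits
    return bool((word >> (token_id % 32)) & 1)


print(words_per_row(150_000))     # 4688 words, about 18.3 KiB per row
print(pack_allowed([True] * 64))  # [-1, -1]: every token allowed
row = pack_allowed([t == 33 for t in range(40)])
print(row, is_allowed(row, 33), is_allowed(row, 32))  # [0, 2] True False
```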

A grammar is compiled once per request into a matcher object, and at each step the matcher fills one row of that bitmask:

# vllm/v1/structured_output/backend_xgrammar.py
def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
    self.matcher.fill_next_token_bitmask(bitmask, idx)

Source: vllm/v1/structured_output/backend_xgrammar.py

fill_next_token_bitmask is the expensive part, and it is the part XGrammar exists to make fast. The naive cost is brutal: for each of the 150,000 vocabulary tokens, ask “could the grammar’s current state accept this token’s characters?”, which is a scan over the whole vocabulary every single step. XGrammar’s trick, summarized in the XGrammar paper (below), is to notice that most of that answer does not depend on the full parser state. At a given grammar state, most tokens can be judged legal or illegal from that state alone, without consulting the stack of enclosing rules that a nested grammar like JSON has to track; only a minority of tokens are “context-dependent” and need a live check against that stack. XGrammar precomputes the context-independent answers for each state ahead of time, so the per-step work shrinks to a cache lookup plus a small live check over the few context-dependent tokens, rather than a fresh scan. The point for us is that this is a CPU computation reading CPU-resident grammar state, with no GPU involvement whatsoever, which is exactly what lets it hide behind the GPU’s forward pass.

The curve below makes the difference concrete: the naive per-step cost is one live check per vocabulary token, so it grows linearly as vocabularies climb toward and past 150,000 tokens, while XGrammar’s per-step live work tracks only the small context-dependent set (a few percent of the vocabulary) and stays almost flat, which is what shrinks the per-step bookkeeping enough to hide behind a forward pass.

Illustrative: the naive line is one check per token by definition; the XGrammar line assumes a representative ~3% context-dependent fraction (the precomputed always/never tokens cost only a cache lookup, not a live check). The real fraction depends on the grammar and tokenizer.
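
The split between precomputed and live work can be sketched as a toy cache. Everything here is hypothetical and far simpler than XGrammar: states are plain integers, classify stands in for the offline analysis, and live_check stands in for the runtime check against the parser's stack. What the sketch does show is the cost structure: per step, only the context-dependent tokens are checked.

```python
from typing import Callable, Dict, FrozenSet, Iterable, List, Set, Tuple

ACCEPT, REJECT, DEPENDS = "accept", "reject", "depends"


class ToyMaskCache:
    def __init__(self, vocab: Iterable[int], states: Iterable[int],
                 classify: Callable[[int, int], str]) -> None:
        vocab = list(vocab)
        self.accepted: Dict[int, FrozenSet[int]] = {}
        self.dependent: Dict[int, List[int]] = {}
        # Offline: classify every (state, token) pair once, before serving.
        for s in states:
            self.accepted[s] = frozenset(t for t in vocab if classify(s, t) == ACCEPT)
            self.dependent[s] = [t for t in vocab if classify(s, t) == DEPENDS]

    def allowed(self, state: int, live_check: Callable[[int], bool]) -> Tuple[Set[int], int]:
        # Online: start from the cached accept set, then live-check only the
        # context-dependent tokens. Returns (allowed set, number of live checks).
        result = set(self.accepted[state])
        for t in self.dependent[state]:
            if live_check(t):
                result.add(t)
        return result, len(self.dependent[state])


# Toy rule for a 100-token vocabulary: tokens 0-9 always legal in state 0,
# tokens 10-12 depend on context, everything else is illegal.
def classify(state: int, token: int) -> str:
    if token < 10:
        return ACCEPT
    if token < 13:
        return DEPENDS
    return REJECT


cache = ToyMaskCache(range(100), states=[0], classify=classify)
allowed, live_checks = cache.allowed(0, live_check=lambda t: t % 2 == 0)
print(sorted(allowed), live_checks)  # tokens 0-10 and 12; 3 live checks instead of 100
```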

Two backend details are worth pausing on because they reappear later. First, the matcher is created with a rollback budget tied to speculation:

# vllm/v1/structured_output/backend_xgrammar.py
return XgrammarGrammar(
    matcher=xgr.GrammarMatcher(
        ctx,
        max_rollback_tokens=self.num_speculative_tokens,
    ),
    ...
)

Source: vllm/v1/structured_output/backend_xgrammar.py

That max_rollback_tokens is not incidental. It is exactly the number of draft tokens the speculator may propose, because every one of them might be rejected and need un-doing. Second, compilation is not free. Compiling a JSON schema or EBNF grammar into a matcher can take tens of milliseconds, which is catastrophic if it happens inline on the engine’s step loop. vLLM moves it off the loop entirely.

Keeping compilation off the loop: the WAITING_FOR_STRUCTURED_OUTPUT_GRAMMAR state

When a structured request arrives, the manager submits its grammar compilation to a thread pool and stores the resulting Future on the request rather than blocking:

# vllm/v1/structured_output/__init__.py
if self._use_async_grammar_compilation:
    grammar = self.executor.submit(self._create_grammar, request)
else:
    grammar = self._create_grammar(request)
request.structured_output_request.grammar = grammar

Source: vllm/v1/structured_output/__init__.py

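The pool is sized deliberately for CPU-bound work, at half the cores rounded up ((cpu_count() + 1) // 2), rather than ThreadPoolExecutor’s default of min(32, os.cpu_count() + 4) (five times the core count before Python 3.8), which is tuned for I/O-bound work: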

# vllm/v1/structured_output/__init__.py
max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
self.executor = ThreadPoolExecutor(max_workers=max_workers)

Source: vllm/v1/structured_output/__init__.py

While that future is pending, the scheduler from Chapter 5 must not run the request, or it would reach the sampler with no mask to apply, and the model could emit a malformed first token before the grammar even exists. This is the purpose of the WAITING_FOR_STRUCTURED_OUTPUT_GRAMMAR request status: it is a blocked-waiting state, parked in a side queue, that the scheduler skips until the grammar is ready. The state diagram below traces a structured request from arrival (the node labelled WAITING_FOR_GRAMMAR is that WAITING_FOR_STRUCTURED_OUTPUT_GRAMMAR status, abbreviated to fit): it is held out of the runnable WAITING queue while its compilation future is pending, and is only promoted once a poll finds the future resolved.

stateDiagram-v2
    [*] --> WAITING_FOR_GRAMMAR: structured request arrives, compile submitted to thread pool
    WAITING_FOR_GRAMMAR --> WAITING_FOR_GRAMMAR: scheduler poll, future still pending, skip
    WAITING_FOR_GRAMMAR --> WAITING: poll finds grammar ready, promote
    WAITING --> RUNNING: scheduler admits request to a batch
    RUNNING --> RUNNING: decode step, mask, sample, advance FSM
    RUNNING --> FINISHED_ERROR: accept_tokens fails, mask and FSM disagree
    RUNNING --> [*]: grammar terminated or stop token

The promotion check is just a poll of the future’s completion:

# vllm/v1/core/sched/scheduler.py
if request.status == RequestStatus.WAITING_FOR_STRUCTURED_OUTPUT_GRAMMAR:
    structured_output_req = request.structured_output_request
    if not (structured_output_req and structured_output_req.grammar):
        return False
    request.status = RequestStatus.WAITING
    return True

Source: vllm/v1/core/sched/scheduler.py
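
The compile-off-the-loop pattern is plain concurrent.futures. In the sketch below, ToyRequest, compile_grammar, and try_promote are hypothetical stand-ins rather than vLLM classes, the sleep fakes compilation latency, and the pool uses the same (cpu_count() + 1) // 2 sizing rule quoted above. Each pass through the while loop plays the role of one scheduler step that skips the request.

```python
import multiprocessing
import time
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from typing import Optional

WAITING_FOR_GRAMMAR, WAITING = "WAITING_FOR_GRAMMAR", "WAITING"


@dataclass
class ToyRequest:
    schema: str
    future: Optional[Future] = None
    status: str = WAITING_FOR_GRAMMAR

    @property
    def grammar(self) -> Optional[str]:
        # Like the scheduler's check: no grammar until the future has resolved.
        if self.future is None or not self.future.done():
            return None
        return self.future.result()


def compile_grammar(schema: str) -> str:
    time.sleep(0.05)  # stand-in for tens of milliseconds of compilation
    return f"matcher<{schema}>"


def try_promote(req: ToyRequest) -> bool:
    """Non-blocking poll: promote to WAITING only once the grammar exists."""
    if req.status == WAITING_FOR_GRAMMAR:
        if req.grammar is None:
            return False
        req.status = WAITING
    return True


max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
with ThreadPoolExecutor(max_workers=max_workers) as pool:
    req = ToyRequest(schema='{"type": "object"}')
    req.future = pool.submit(compile_grammar, req.schema)
    polls = 0
    while not try_promote(req):  # the request is skipped on this step
        polls += 1
        time.sleep(0.01)         # the engine would run other requests here
    print(req.status, req.grammar, "after", polls, "skipped polls")
```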

The comment in StructuredOutputManager.__init__ is candid about why this is conditional on the executor backend: under external_launcher mode there is one scheduler per TP rank, and asynchronous compilation would let the WAITING_FOR_STRUCTURED_OUTPUT_GRAMMAR -> WAITING transition fire at different times on different ranks, breaking the lockstep determinism those ranks rely on. So in that mode compilation is forced synchronous. This is a recurring flavor of distributed-systems tax: an optimization that is purely local on one process becomes a consistency hazard the moment the same decision must be made identically on many.

The CPU-computes / GPU-applies handshake

Now the steady-state path. Each step, after the scheduler decides the batch, it asks the structured-output manager to build a compacted bitmask covering only the structured requests, in a known order:

# vllm/v1/core/sched/scheduler.py
bitmask = self.structured_output_manager.grammar_bitmask(
    self.requests,
    structured_output_request_ids,
    scheduler_output.scheduled_spec_decode_tokens,
)
return GrammarOutput(structured_output_request_ids, bitmask)

Source: vllm/v1/core/sched/scheduler.py

grammar_bitmask fills rows by calling each grammar’s fill_bitmask, and for large batches without speculation it shards that filling across a second, dedicated thread pool in chunks of 16 requests:

# vllm/v1/structured_output/__init__.py
if len(batch) == self.fill_bitmask_parallel_batch_size:
    promises.append(self._async_submit_fill_bitmask(batch))
    batch = []

Source: vllm/v1/structured_output/__init__.py
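
The chunked fan-out can be sketched with the standard library. The chunk size of 16 follows the text; fill_row is a hypothetical placeholder for a grammar's fill_bitmask, and rows are Python lists rather than one shared tensor. Because each request owns a distinct row, workers never write the same memory, which is what lets the fill be sharded without locks.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Sequence, Tuple

CHUNK = 16  # requests per submitted batch, as described in the text

Row = List[int]
FillFn = Callable[[int, Row], None]


def fill_rows(bitmask: List[Row], items: Sequence[Tuple[int, int]], fill_row: FillFn) -> None:
    for req_id, row_idx in items:
        fill_row(req_id, bitmask[row_idx])  # each item owns a distinct row


def parallel_fill(bitmask: List[Row], items: Sequence[Tuple[int, int]],
                  fill_row: FillFn, pool: ThreadPoolExecutor) -> None:
    promises, batch = [], []
    for item in items:
        batch.append(item)
        if len(batch) == CHUNK:
            promises.append(pool.submit(fill_rows, bitmask, batch, fill_row))
            batch = []  # a new list, so the submitted chunk is never mutated
    if batch:  # the leftover partial chunk
        promises.append(pool.submit(fill_rows, bitmask, batch, fill_row))
    for p in promises:
        p.result()  # block until done; re-raises any worker exception


def fake_fill(req_id: int, row: Row) -> None:
    row[:] = [req_id] * len(row)  # stand-in for writing a real mask


bitmask = [[-1] * 4 for _ in range(40)]
items = [(100 + i, i) for i in range(40)]  # (request id, row index): 3 chunks
with ThreadPoolExecutor(max_workers=4) as pool:
    parallel_fill(bitmask, items, fake_fill, pool)
print(bitmask[0], bitmask[39])  # [100, 100, 100, 100] [139, 139, 139, 139]
```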

Then comes the seam that matters most. The completed mask is deliberately converted to a NumPy array before it is shipped to the GPU workers, because, as the source comment notes, “serialization of np.ndarray is much more efficient than a tensor”. On the worker, applying it takes a single kernel:

# vllm/v1/structured_output/utils.py
grammar_bitmask = torch.from_numpy(sorted_bitmask).to(
    logits.device, non_blocking=True
)
...
xgr.apply_token_bitmask_inplace(logits, grammar_bitmask, indices=index_tensor)

Source: vllm/v1/structured_output/utils.py

Note the non_blocking=True host-to-device copy and the in-place kernel that sets masked logits to negative infinity. The non_blocking flag lets the small bitmask copy be queued on the GPU stream without the CPU stalling to wait for it (fully asynchronous only when the source sits in pinned host memory), so it rides alongside the work already in flight. The GPU does the absolute minimum: receive a small bitmask and overwrite every logit whose bit is clear with negative infinity. All the grammar walking already happened on CPU, overlapped with the GPU’s forward pass for this very step. The sequence diagram below traces one decode step and makes the overlap explicit: while the GPU is busy with the forward pass, the CPU is already filling the bitmask for the same step, so when the logits land the mask is waiting.

sequenceDiagram
    participant Sched as "Scheduler (CPU)"
    participant SOM as "StructuredOutputManager (CPU)"
    participant GPU as "Model runner (GPU)"
    Sched->>GPU: launch forward pass for the batch
    Sched->>SOM: grammar_bitmask(requests, ids, spec tokens)
    Note over SOM: fill_bitmask per request, sharded across a thread pool
    SOM-->>Sched: packed bitmask as np.ndarray
    GPU-->>Sched: logits ready
    Sched->>GPU: apply_grammar_bitmask(logits, bitmask)
    Note over GPU: non_blocking copy then in-place AND, illegal logits to -inf
    GPU->>GPU: sample a legal token
    GPU-->>Sched: sampled token ids
    Sched->>SOM: accept_tokens, advance each FSM

The call site lives right before sampling in the model runner, exactly where Chapter 11 said the logits live:

# vllm/v1/worker/gpu_model_runner.py
if grammar_output is not None:
    apply_grammar_bitmask(
        scheduler_output, grammar_output, self.input_batch, logits
    )

Source: vllm/v1/worker/gpu_model_runner.py

One subtlety in apply_grammar_bitmask repays attention, and it is purely a row-alignment problem. The logits tensor has one row per request in the batch, in the batch’s order, mixing structured and free-form requests freely. The bitmask, by contrast, was built in a different order (that of structured_output_request_ids) and contains only the structured requests, because those are the only ones with a grammar to fill. The kernel applies row i of the mask to row i of the logits, so if the two orderings do not match, request A’s constraints land on request B’s logits, silently corrupting one output while the batch looks healthy. To prevent this, the function rebuilds a sorted_bitmask that has the same number of rows as the logits and is aligned to them: every free-form row is filled with -1 (all bits set, meaning every token allowed, so the mask changes nothing), and only the structured rows are overwritten with their real masks. As a fast path, if every logit row turns out to be a structured row, the alignment is already trivial and it skips passing the per-row indices to the kernel altogether. This is the kind of index bookkeeping any batched per-request transform demands; here it is load-bearing for correctness.
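
The alignment step can be sketched in pure Python. The function names are hypothetical, logits and masks are lists rather than GPU tensors, and the bit layout is assumed LSB-first; the point is the scatter of compact rows into a batch-aligned buffer pre-filled with -1, followed by a per-row application that sends the logit of every cleared bit to -inf.

```python
import math
from typing import Dict, List, Sequence


def align_bitmask(batch_req_ids: Sequence[str], structured_ids: Sequence[str],
                  compact: Sequence[Sequence[int]], words: int) -> List[List[int]]:
    # compact[j] is the mask for structured_ids[j]; build one row per batch row.
    aligned = [[-1] * words for _ in batch_req_ids]  # -1: all bits set, a no-op row
    position: Dict[str, int] = {rid: i for i, rid in enumerate(batch_req_ids)}
    for j, rid in enumerate(structured_ids):
        aligned[position[rid]] = list(compact[j])
    return aligned


def apply_bitmask(logits: List[List[float]], aligned: Sequence[Sequence[int]]) -> None:
    # In place: a cleared bit sends that token's logit to -inf.
    for row, mask in zip(logits, aligned):
        for token_id in range(len(row)):
            word = mask[token_id // 32] & 0xFFFFFFFF
            if not (word >> (token_id % 32)) & 1:
                row[token_id] = -math.inf


logits = [[1.0, 2.0, 3.0, 4.0],   # "free": a free-form request
          [1.0, 2.0, 3.0, 4.0]]   # "json": grammar allows tokens 0 and 2
aligned = align_bitmask(["free", "json"], ["json"], [[0b101]], words=1)
apply_bitmask(logits, aligned)
print(logits)  # [[1.0, 2.0, 3.0, 4.0], [1.0, -inf, 3.0, -inf]]
```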

After the FSM: advancing and accepting

Masking is only half of correctness. Masking decides what may come next; advancing records what actually came, so the two must stay in step. Once a token is sampled, the grammar’s FSM has to advance so that the next step’s mask reflects the new state, otherwise the grammar would keep offering the same opening tokens forever. That advance happens back in the scheduler’s update_from_output, gated by should_advance. The gate exists because some models emit a reasoning phase first, a free-form scratchpad of “thinking” tokens before the structured answer begins, and those tokens must not drive the grammar (the grammar describes the answer, not the scratchpad). So should_advance returns false during reasoning and the FSM stays put until the real answer starts:

# vllm/v1/core/sched/scheduler.py
if new_token_ids and self.structured_output_manager.should_advance(request):
    struct_output_request = request.structured_output_request
    ...
    if not struct_output_request.grammar.accept_tokens(req_id, new_token_ids):
        ...
        request.status = RequestStatus.FINISHED_ERROR

Source: vllm/v1/core/sched/scheduler.py

In the happy path accept_tokens always succeeds, because the mask guaranteed the sampled token was legal. The defensive failure branch terminates the request, which is the honest thing to do: a rejection here means the mask and the FSM disagreed, a bug, not a recoverable condition.

Rolling back the mask over rejected speculative tokens

This is where the chapter’s dependence on Chapter 13 becomes concrete. With speculation, a single forward pass verifies several drafted tokens at once: a drafter (a small model, an EAGLE-style head, or an n-gram lookup) proposes a few likely next tokens, and the big model checks them all in one pass, keeping the longest accepted prefix. Each draft position needs its own mask, because the legal set at position $k+1$ depends on what was emitted at position $k$, and to know that set the FSM must first be advanced as if the draft token at position $k$ were accepted. So the masks for a speculative step can only be computed by walking the FSM forward through the draft, one hypothetical token at a time. Concretely, grammar_bitmask allocates room for every speculative slot plus a bonus token (the bonus is the extra “free” token a verification pass yields when the whole draft is accepted):

# vllm/v1/structured_output/__init__.py
self._grammar_bitmask = self.backend.allocate_token_bitmask(
    max_batch_size * (1 + max_num_spec_tokens)
)

Source: vllm/v1/structured_output/__init__.py

and the serial spec-decode path walks the draft tokens, advancing the FSM as it fills each position, then rolls the FSM back by exactly the number of advances it made:

# vllm/v1/structured_output/__init__.py
for token in itertools.chain(req_tokens, (-1,)):
    self._fill_bitmasks(((grammar, cumulative_index, apply_bitmask),))
    ...
    if apply_bitmask and not grammar.is_terminated():
        accepted = grammar.accept_tokens(req_id, [token])
        ...
        state_advancements += 1
    cumulative_index += 1
if state_advancements > 0:
    grammar.rollback(state_advancements)

Source: vllm/v1/structured_output/__init__.py

The reason for the rollback is delicate, and it is the crux of how masking and speculation coexist. To predict the mask for draft position $k$, the FSM had to be advanced through the drafted tokens of positions $0 \ldots k-1$, hypothesizing that they will be accepted. But the rejection sampler (Chapter 13) has the final say, and it may reject some suffix of those drafts. The FSM cannot commit to the hypothetical advances, because if it did and the drafts were then rejected, the grammar’s state would be ahead of the tokens actually emitted, and every future mask would be wrong. So the manager treats the forward walk as throwaway scaffolding: it speculatively advances purely to compute the masks, counts exactly how many advances it made (state_advancements), and immediately rolls all the way back to the pre-step state. The real advance, over only the tokens the rejection sampler actually accepted, happens afterward in update_from_output via accept_tokens. The -1 sentinel appended to the loop marks the bonus/non-speculative slot; once it is hit the code stops advancing, so it fills that slot’s mask but never walks the grammar past a padding token.

The flowchart below traces the loop for one request: fill a mask, hypothetically advance, repeat per draft token, then unwind every advance so the grammar is back where it started.

flowchart TD
    A["start: FSM at pre-step state, state_advancements = 0"] --> B["fill bitmask for current position"]
    B --> C{"sentinel -1 or grammar terminated?"}
    C -->|"no"| D["accept_tokens([draft token]): hypothetically advance FSM"]
    D --> E["state_advancements += 1"]
    E --> B
    C -->|"yes"| F{"state_advancements > 0?"}
    F -->|"yes"| G["rollback(state_advancements): FSM back to pre-step state"]
    F -->|"no"| H["done; masks ready for verification pass"]
    G --> H
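
The fill, hypothetically advance, and unwind loop can be exercised end to end with a toy matcher. ToyChoiceMatcher is hypothetical (its output must be one of a few fixed strings and every token is one character), but it mirrors the calls used above: accept_tokens, rollback, and is_terminated, plus a max_rollback_tokens budget that makes rollback fail when the draft is deeper than the budget.

```python
from typing import List, Optional, Sequence, Set


class ToyChoiceMatcher:
    """Hypothetical matcher: the output must be exactly one of `choices`,
    and every token is a single character."""

    def __init__(self, choices: Sequence[str], max_rollback_tokens: int) -> None:
        self.choices = list(choices)
        self.max_rollback_tokens = max_rollback_tokens
        self.prefix = ""  # everything accepted so far

    def allowed_next(self) -> Set[str]:
        n = len(self.prefix)
        return {c[n] for c in self.choices if c.startswith(self.prefix) and len(c) > n}

    def is_terminated(self) -> bool:
        # A real matcher would still allow end-of-sequence here.
        return self.prefix in self.choices and not self.allowed_next()

    def accept_tokens(self, tokens: Sequence[str]) -> bool:
        for tok in tokens:
            if tok not in self.allowed_next():
                return False
            self.prefix += tok
        return True

    def rollback(self, num_tokens: int) -> None:
        if num_tokens > min(self.max_rollback_tokens, len(self.prefix)):
            raise ValueError("rollback exceeds the matcher's budget")
        self.prefix = self.prefix[: len(self.prefix) - num_tokens]


def masks_for_draft(m: ToyChoiceMatcher, draft: Sequence[str]) -> List[Set[str]]:
    """One allowed-set per draft slot plus the bonus slot; state restored afterwards."""
    masks: List[Set[str]] = []
    advances = 0
    slots: List[Optional[str]] = list(draft) + [None]  # None plays the -1 sentinel
    for tok in slots:
        masks.append(m.allowed_next())
        if tok is None or m.is_terminated():
            break
        if not m.accept_tokens([tok]):
            # This slot's mask forbids tok, so verification rejects it and
            # the later slots are never used.
            break
        advances += 1
    if advances:
        m.rollback(advances)  # undo every hypothetical advance
    return masks


m = ToyChoiceMatcher(["true", "false"], max_rollback_tokens=3)
m.accept_tokens(["t"])             # committed in an earlier step
masks = masks_for_draft(m, ["r", "u", "e"])
print([sorted(s) for s in masks])  # [['r'], ['u'], ['e'], []]
print(m.prefix)                    # t
m.accept_tokens(["r"])             # verification accepted only the first draft
print(m.prefix)                    # tr
```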

The grammar’s own rollback keeps its bookkeeping honest, decrementing the processed-token count and re-checking termination:

# vllm/v1/structured_output/backend_xgrammar.py
def rollback(self, num_tokens: int) -> None:
    self.matcher.rollback(num_tokens)
    self.num_processed_tokens -= num_tokens
    self._is_terminated = self.matcher.is_terminated()

Source: vllm/v1/structured_output/backend_xgrammar.py

This is why max_rollback_tokens was wired to num_speculative_tokens back at matcher construction: the rollback depth the engine needs is bounded precisely by how far speculation looks ahead. Get that bound wrong and rollback fails; tie it to the spec config and it is correct by construction.

Backends are interchangeable behind one interface

vLLM does not bet on a single grammar engine. The manager dispatches on a configured backend name to XGrammar, Guidance (llguidance), Outlines, or LM Format Enforcer, all implementing the same StructuredOutputBackend / StructuredOutputGrammar pair. The Guidance backend, for instance, presents the identical fill_bitmask shape over a different matcher:

# vllm/v1/structured_output/backend_guidance.py
def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
    llguidance_torch.fill_next_token_bitmask(self.ll_matcher, bitmask, idx)
    self.check_error()

Source: vllm/v1/structured_output/backend_guidance.py

The differences live below the interface: which JSON-schema features each supports (xgrammar rejects multipleOf, uniqueItems, patternProperties, and more in has_xgrammar_unsupported_json_features; guidance rejects patternProperties), how grammars are specified (xgrammar converts Lark, a grammar notation, into EBNF, a standard grammar-description syntax; a simple list of allowed choices becomes a tiny root ::= "a" | "b" grammar, read as “the whole output is either the string a or the string b”), and engine-specific niceties like jump-forward decoding, which both backends note as future work. The uniform interface is what lets the rollback-and-mask machinery above stay backend-agnostic: swap the engine and the scheduler code above does not change.
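
Two of those backend chores are small enough to sketch. choices_to_ebnf turns a choice list into the root ::= "a" | "b" form, assuming the grammar syntax accepts JSON-style escapes inside its string literals; uses_unsupported walks a schema for banned keywords. Both are hypothetical helpers, not vLLM's code, and the keyword set is an illustrative subset. The walk is deliberately naive: a property that happens to be named multipleOf would be a false positive.

```python
import json
from typing import Any, Iterable, Sequence

# Hypothetical subset; vLLM's real checks cover more keywords and differ per backend.
UNSUPPORTED_KEYWORDS = frozenset({"multipleOf", "uniqueItems", "patternProperties"})


def choices_to_ebnf(choices: Sequence[str]) -> str:
    # json.dumps yields a double-quoted literal with quotes and backslashes escaped.
    return "root ::= " + " | ".join(json.dumps(c, ensure_ascii=False) for c in choices)


def uses_unsupported(schema: Any, unsupported: Iterable[str] = UNSUPPORTED_KEYWORDS) -> bool:
    """Walk a JSON schema and report whether any banned keyword appears as a key."""
    banned = set(unsupported)
    stack = [schema]
    while stack:
        node = stack.pop()
        if isinstance(node, dict):
            if any(key in banned for key in node):
                return True
            stack.extend(node.values())
        elif isinstance(node, list):
            stack.extend(node)
    return False


print(choices_to_ebnf(["a", "b"]))  # root ::= "a" | "b"
print(uses_unsupported({"type": "array", "items": {"type": "number", "multipleOf": 5}}))  # True
```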

What is still hard

The handshake is clean, but it is not free, and several rough edges remain. Compilation latency for large recursive schemas can still stall a request’s first token even off the critical path, because the request waits in WAITING_FOR_STRUCTURED_OUTPUT_GRAMMAR until its future resolves; pathological grammars can dominate TTFT. The per-step CPU mask fill, even sharded, competes for the same cores as detokenization and scheduling, so under high structured-request concurrency the CPU can become the bottleneck the GPU never was, which is exactly the kind of regression Chapter 21 will teach you to localize. Jump-forward decoding is the obvious next throughput win and is still unimplemented in both major backends here: when the grammar leaves exactly one legal continuation (after "name": in a schema whose name field is a string, the next non-whitespace character must be ", for instance), the engine could append those forced tokens directly instead of spending a decode step on each, since the model has no real choice to make. Those tokens are nearly free, but the bookkeeping to splice them into the KV cache and the sampler path is nontrivial, so it remains future work. And the interaction with reasoning models, where the grammar must hold its fire during a thinking phase and engage only at the answer, is governed by the should_advance / should_fill_bitmask logic and remains visibly under active development, with comments flagging code paths slated for unification.
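
The core test behind jump-forward decoding is easy to state in code, even though the KV-cache splicing that makes it hard is not shown. This is a hypothetical sketch with one character per token and a fixed list of choices: keep appending while the grammar allows exactly one next token, and stop at the first real choice. With real multi-character tokenizers the forced text must also be retokenized carefully, which this ignores.

```python
from typing import Callable, Set

CHOICES = ["positive", "negative", "neutral"]


def allowed_next(prefix: str) -> Set[str]:
    # Toy grammar: the whole output must be one of CHOICES, one character per token.
    n = len(prefix)
    return {c[n] for c in CHOICES if c.startswith(prefix) and len(c) > n}


def forced_continuation(prefix: str, allowed: Callable[[str], Set[str]],
                        max_tokens: int = 64) -> str:
    """Collect the tokens the grammar forces, stopping at the first real choice."""
    forced = ""
    while len(forced) < max_tokens:
        options = allowed(prefix + forced)
        if len(options) != 1:  # zero means finished; two or more means the model decides
            break
        forced += next(iter(options))
    return forced


print(repr(forced_continuation("", allowed_next)))     # '' (p or n is a real choice)
print(repr(forced_continuation("p", allowed_next)))    # 'ositive'
print(repr(forced_continuation("neg", allowed_next)))  # 'ative'
```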

The deeper lesson generalizes past grammars. Structured generation is a template for any “constrain or bias the next token” feature: do the bookkeeping on the CPU, overlap it with the GPU’s forward pass, hand the GPU a compact mask, and respect speculation by making every hypothetical advance reversible. The next part of the book leaves the single replica behind and asks how a model spreads across many GPUs at all, starting with tensor, pipeline, and expert parallelism.

Further reading

  • XGrammar: Flexible and Efficient Structured Generation Engine for Large Language Models — arXiv:2411.15100 — Precomputes context-independent token sets per FSM state to make per-step vocabulary masking cheap; read it for the data structures that make fill_next_token_bitmask fast enough to hide behind a forward pass.

Distributed inference: tensor, pipeline, and expert parallelism with expert-parallel load balancing

Everything in Part II assumed the model fit on one GPU. That assumption holds for a 7B model in bfloat16, and Chapter 12 stretched it further: quantize the weights to FP8 or 4-bit and a surprisingly large model squeezes onto a single device, with the KV cache from Chapter 4 fighting it for the leftover HBM. But the assumption breaks eventually. A 405B dense model is about 810 GB of weights in bfloat16 before you reserve a single block for KV. A trillion-parameter mixture-of-experts model is far past that. And even when the weights fit, a model can be too slow on one GPU: decode is memory-bandwidth-bound, and one device only has so many bytes per second to read its weights with.
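
The sizing arithmetic behind those claims fits in a few lines. The 80 GB of HBM and 3 TB/s of bandwidth per GPU are assumed round numbers, and the counts cover weights only, with nothing reserved for KV cache or activations.

```python
import math

BYTES_PER_PARAM = {"bf16": 2.0, "fp8": 1.0, "int4": 0.5}


def weight_gb(params_billion: float, dtype: str) -> float:
    # 1e9 parameters x bytes per parameter = that many GB (decimal).
    return params_billion * BYTES_PER_PARAM[dtype]


def min_gpus_for_weights(params_billion: float, dtype: str, hbm_gb: float = 80.0) -> int:
    # Weights only: reserves nothing for KV cache, activations, or runtime overhead.
    return math.ceil(weight_gb(params_billion, dtype) / hbm_gb)


def decode_floor_ms(params_billion: float, dtype: str, tb_per_s: float, gpus: int = 1) -> float:
    # Every decode step streams each GPU's share of the weights once;
    # GB divided by TB/s comes out in milliseconds.
    return weight_gb(params_billion, dtype) / gpus / tb_per_s


for dtype in ("bf16", "fp8", "int4"):
    print(dtype, weight_gb(405, dtype), "GB ->", min_gpus_for_weights(405, dtype), "GPUs")
print(decode_floor_ms(7, "bf16", 3.0))           # about 4.7 ms per step on one GPU
print(decode_floor_ms(405, "fp8", 3.0, gpus=8))  # 16.875 ms with weights split 8 ways
```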

So you split the model across devices. The interesting part is that there is no single way to split it. There are three axes, each cutting the model along a different grain, and each pays for itself in a different communication currency with a different failure mode. Before going deep, it helps to hold the three side by side. Tensor parallelism (TP) cuts across every layer, so the work of one matrix multiply is shared by several GPUs; pipeline parallelism (PP) cuts along the depth of the model, so each GPU owns a contiguous run of whole layers; expert parallelism (EP) exists only for mixture-of-experts models and cuts by which expert, putting different expert sub-networks on different GPUs. The diagram below contrasts the three cuts on the same toy model.

flowchart LR
    subgraph TP["Tensor parallel: shard within a layer"]
        direction TB
        TPL["one matmul"] --> TPa["GPU 0: left columns"]
        TPL --> TPb["GPU 1: right columns"]
    end
    subgraph PP["Pipeline parallel: split by depth"]
        direction TB
        PPa["GPU 0: layers 0 to N-1"] --> PPb["GPU 1: layers N to 2N-1"]
    end
    subgraph EP["Expert parallel: split by expert"]
        direction TB
        EPa["GPU 0: experts 0 to 7"]
        EPb["GPU 1: experts 8 to 15"]
    end

This chapter walks the three, grounds each in vLLM’s process-group machinery, and then confronts what makes large MoE serving genuinely hard in production: once experts live on different GPUs, throughput is hostage to how evenly traffic spreads across them. The chapter closes by watching vLLM move experts between GPUs under live load to fix that skew.

One mesh, four axes

Before any layer is sharded, vLLM has to decide which GPU belongs to which group. That bookkeeping lives in initialize_model_parallel, and the cleanest way to understand the whole chapter is to read how it lays out the ranks:

# vllm/distributed/parallel_state.py
# the layout order is: ExternalDP x DP x PP x TP
all_ranks = torch.arange(world_size).reshape(
    -1,
    data_parallel_size,
    pipeline_model_parallel_size,
    prefill_context_model_parallel_size,
    tensor_model_parallel_size,
)  # noqa

Source: vllm/distributed/parallel_state.py

A rank is just a global integer name for one GPU process: in a cluster of world_size GPUs, the ranks run from 0 to world_size - 1. The reshape above takes that flat list of ranks and folds it into a multi-dimensional grid. There are five dimensions here, ordered from slowest-varying to fastest: the trailing -1 absorbs whatever is left over as “external data parallel”, then data parallel (DP), pipeline parallel (PP), a prefill-context-parallel dimension (a fifth axis for splitting long-prompt prefill that this chapter sets aside), and finally tensor parallel (TP) as the innermost, fastest-varying axis. The three axes this chapter cares about — TP, PP, and EP — are all carved from this same grid. Every GPU is one cell in it, and its coordinates along each axis say which TP group, which PP stage, and which DP replica it belongs to. The grid is the single source of truth from which every communication group is derived.

To build a particular parallel group, vLLM transposes the axis it cares about to the end, flattens, and unbinds: it rearranges the rank grid so the axis of interest runs contiguously, then reads off each row as one group of ranks that will talk to each other. Tensor-parallel groups are just consecutive ranks (all_ranks.view(-1, tensor_model_parallel_size)), so with a TP size of 8, ranks 0–7 form one TP group, 8–15 the next, and so on; pipeline groups stride across the PP dimension (all_ranks.transpose(2, 4)...), picking one rank from each stage; the data-parallel and expert-parallel groups come from their own transposes. Each call hands its rank lists to init_model_parallel_group, which wraps them in a GroupCoordinator, the object that owns the actual NCCL communicators (the GPU-to-GPU message channels) and exposes all_reduce, all_gather, send, and recv for that group. The docstring at the top of the function is worth internalizing because it dictates topology:

# vllm/distributed/parallel_state.py
# Note that for efficiency, the caller should make sure adjacent ranks
# are on the same DGX box.

Source: vllm/distributed/parallel_state.py

Adjacency is not cosmetic. TP groups are consecutive precisely because TP is the most communication-hungry axis and wants the fattest links (NVLink within a box), while PP, which exchanges far less, can tolerate slower cross-node hops. Hold that ordering in mind; it is the reason the three axes have the cost profiles they do.
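
The grid arithmetic is easy to reproduce without torch. This sketch assumes the external-DP and prefill-context-parallel sizes are 1, so the five-dimensional grid collapses to (DP, PP, TP), and the helper names are ours; it mirrors the row-major rank numbering of the reshape rather than calling vLLM.

```python
from itertools import product
from typing import List


def rank_of(dp: int, pp: int, tp: int, pp_size: int, tp_size: int) -> int:
    # Row-major over (DP, PP, TP): TP varies fastest, like the reshape above.
    return (dp * pp_size + pp) * tp_size + tp


def tp_groups(dp_size: int, pp_size: int, tp_size: int) -> List[List[int]]:
    # Fix (dp, pp), vary tp: consecutive ranks.
    return [[rank_of(d, p, t, pp_size, tp_size) for t in range(tp_size)]
            for d, p in product(range(dp_size), range(pp_size))]


def pp_groups(dp_size: int, pp_size: int, tp_size: int) -> List[List[int]]:
    # Fix (dp, tp), vary pp: one rank from each stage, strided by tp_size.
    return [[rank_of(d, p, t, pp_size, tp_size) for p in range(pp_size)]
            for d, t in product(range(dp_size), range(tp_size))]


def dp_groups(dp_size: int, pp_size: int, tp_size: int) -> List[List[int]]:
    # Fix (pp, tp), vary dp: the same shard position in each replica.
    return [[rank_of(d, p, t, pp_size, tp_size) for d in range(dp_size)]
            for p, t in product(range(pp_size), range(tp_size))]


# 16 GPUs as DP=2 x PP=2 x TP=4.
print(tp_groups(2, 2, 4))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
print(pp_groups(2, 2, 4))  # [[0, 4], [1, 5], [2, 6], [3, 7], [8, 12], [9, 13], [10, 14], [11, 15]]
print(dp_groups(2, 2, 4))  # [[0, 8], [1, 9], [2, 10], [3, 11], [4, 12], [5, 13], [6, 14], [7, 15]]
```

Consecutive TP ranks are what let the caller map each TP group onto the NVLink domain of one box, as the docstring asks.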

Tensor parallelism: split every matmul, pay an all-reduce per layer

Tensor parallelism is the Megatron idea (arXiv:1909.08053): shard the weight matrices within a layer and have every GPU compute a slice of every operation. vLLM implements it as two complementary linear layers. A ColumnParallelLinear splits the weight along its output dimension so each rank owns a vertical slab:

# vllm/model_executor/layers/linear.py
# Divide the weight matrix along the last dimension.
self.output_size_per_partition = divide(output_size, self.tp_size)

Source: vllm/model_executor/layers/linear.py

Its partner, RowParallelLinear, splits along the input dimension, so each rank consumes the matching slice of the previous layer’s sharded output. The cleverness of Megatron’s pairing is that a column-parallel matmul followed by a row-parallel matmul needs communication only at the very end. Walk it through. The first matmul gives each rank its own vertical slab of the intermediate result — no rank has the whole thing, but together they have all of it, with no overlap. That sharded intermediate then flows straight into the second matmul: because RowParallelLinear is split along its input, each rank’s slab is exactly the input slice that rank needs, so no data has to move between the two matmuls. Each rank multiplies its shard against its slice and produces a partial result — a same-shaped tensor that is correct except that it is missing every other rank’s contribution. To finish, you add the partials together elementwise across ranks. That summation is the all-reduce: a collective in which every rank contributes its tensor and every rank receives the sum. It happens with a single collective:

# vllm/model_executor/layers/linear.py
if self.reduce_results and self.tp_size > 1:
    output = tensor_model_parallel_all_reduce(output_parallel)
else:
    output = output_parallel

Source: vllm/model_executor/layers/linear.py

That tensor_model_parallel_all_reduce is a thin shim over the TP group’s communicator (vllm/distributed/communication_op.py just calls get_tp_group().all_reduce(...)). The diagram below traces a single column-then-row pair on two ranks, showing where the data stays sharded and where the one all-reduce stitches it back together.

flowchart TD
    X["input activation (replicated on every rank)"]
    X --> C0["GPU 0: ColumnParallelLinear, left half of weight"]
    X --> C1["GPU 1: ColumnParallelLinear, right half of weight"]
    C0 --> R0["GPU 0: RowParallelLinear on its shard, partial output"]
    C1 --> R1["GPU 1: RowParallelLinear on its shard, partial output"]
    R0 --> AR["all-reduce: sum partials across ranks"]
    R1 --> AR
    AR --> Y["full output (now replicated again)"]
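The column-then-row identity is easy to check numerically. Below is a pure-Python sketch (not vLLM code; every helper name is hypothetical) that splits a two-layer MLP across tp_size simulated ranks, applies an elementwise ReLU to each rank's slab of the intermediate, and sums the partial outputs in place of a real all-reduce. The ReLU is there to show why the pairing works: an elementwise nonlinearity can run on a column slab without seeing any other rank's columns.

```python
"""Pure-Python sketch of Megatron-style tensor parallelism (hypothetical helpers, not vLLM code)."""


def matmul(a, b):
    """Plain matrix product a @ b; matrices are lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def relu(m):
    """Elementwise nonlinearity: safe to apply to any column slab independently."""
    return [[max(0, v) for v in row] for row in m]


def split_cols(w, parts):
    """Column-parallel split: rank r owns a contiguous slab of output columns."""
    n = len(w[0]) // parts
    return [[row[r * n:(r + 1) * n] for row in w] for r in range(parts)]


def split_rows(w, parts):
    """Row-parallel split: rank r owns the matching slab of input rows."""
    n = len(w) // parts
    return [w[r * n:(r + 1) * n] for r in range(parts)]


def all_reduce_sum(partials):
    """Elementwise sum of same-shaped partial outputs, i.e. what an all-reduce hands every rank."""
    rows, cols = len(partials[0]), len(partials[0][0])
    return [[sum(p[i][j] for p in partials) for j in range(cols)] for i in range(rows)]


def tp_mlp(x, w1, w2, tp_size):
    """Column-parallel w1, ReLU, row-parallel w2, and exactly one all-reduce at the end."""
    w1_shards, w2_shards = split_cols(w1, tp_size), split_rows(w2, tp_size)
    partials = []
    for rank in range(tp_size):
        h = relu(matmul(x, w1_shards[rank]))          # this rank's slab of the intermediate
        partials.append(matmul(h, w2_shards[rank]))   # no communication between the two matmuls
    return all_reduce_sum(partials)


def dense_mlp(x, w1, w2):
    """Unsharded reference: the answer every TP degree must reproduce."""
    return matmul(relu(matmul(x, w1)), w2)


x = [[1, -2, 3, 4], [0, 1, -1, 2]]
w1 = [[1, 0, 2, -1], [0, 1, 1, 3], [2, -1, 0, 1], [1, 1, -2, 0]]
w2 = [[1, 2], [0, -1], [3, 1], [-2, 0]]
for tp_size in (1, 2, 4):
    assert tp_mlp(x, w1, w2, tp_size) == dense_mlp(x, w1, w2)
```

Integer inputs keep the comparison exact; with floating point the sharded and dense results agree only up to summation-order rounding.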

In a transformer block, this happens twice, once after attention’s output projection and once after the MLP’s down-projection, so a TP-sharded forward pass costs two all-reduces per layer per forward step (every token in the batch shares them). That is the defining property of tensor parallelism: it is exact and load-balanced (every rank does identical work), but it injects a blocking collective on the critical path of every single layer. Raising the TP degree does not remove those collectives: each rank’s share of the matmul shrinks, while the two all-reduces per layer remain and now span more ranks, so communication becomes a growing fraction of every step. This is why TP wants NVLink and rarely scales gracefully past one node: the all-reduce latency, not the compute, becomes the wall. Even the vocabulary embedding is sharded this way (vocab_parallel_embedding.py splits the token table across ranks and all-reduces the lookups), so no single GPU has to hold the full embedding table.

A subtle TP detail that bites people: bias. In a RowParallelLinear the partial results get summed, so adding the bias on every rank would add it tp_size times. vLLM fuses the bias into the GEMM only on rank 0:

# vllm/model_executor/layers/linear.py
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias

Source: vllm/model_executor/layers/linear.py
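Why only rank 0? A toy two-rank example makes the arithmetic explicit. This is plain Python with a hypothetical helper, not vLLM code: the all-reduce sums whatever each rank contributes, so a bias added on every rank is counted tp_size times.

```python
def row_parallel_reduce(partials, bias, bias_on_every_rank=False):
    """Sum per-rank partial outputs (one vector per rank), adding the bias on rank 0 only,
    unless bias_on_every_rank is set (the bug the vLLM comment guards against)."""
    out = [0] * len(bias)
    for rank, partial in enumerate(partials):
        add_bias = bias_on_every_rank or rank == 0
        for i, value in enumerate(partial):
            out[i] += value + (bias[i] if add_bias else 0)
    return out


partials = [[1, 2], [3, 4]]   # two ranks' partial sums for a 2-wide output
bias = [10, 20]
print(row_parallel_reduce(partials, bias))                           # [14, 26]: bias counted once
print(row_parallel_reduce(partials, bias, bias_on_every_rank=True))  # [24, 46]: bias counted twice
```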

Pipeline parallelism: split by layer, pay a bubble

Pipeline parallelism cuts the other way. Instead of slicing each layer across GPUs, you give whole contiguous blocks of layers to each GPU: stage 0 holds layers 0 through N−1, stage 1 holds layers N through 2N−1, and a request’s activations flow from one stage to the next. The communication is tiny by comparison (a single hidden-state tensor handed across the stage boundary), and vLLM exposes it as a plain point-to-point send and receive on the PP group:

# vllm/distributed/parallel_state.py
def send(self, tensor: torch.Tensor, dst: int | None = None) -> None:
    """Sends a tensor to the destination rank in a blocking way"""

Source: vllm/distributed/parallel_state.py

So PP’s communication currency is cheap, which is exactly why it survives across slow inter-node links where TP would choke. Its cost is structural instead of bandwidth-bound: the bubble. The bubble is idle GPU time that comes purely from the shape of the dependency, not from any slow operation. Because each stage can only start once the stage before it has handed over its activations, a single batch threads through the stages one at a time: while stage 0 works on a batch, stages $1$ through $P-1$ sit idle waiting for its output; when stage 0 finishes and passes the baton, it goes idle, since there is nothing new behind that batch to feed it. With one batch in flight the pipeline is mostly empty — at any instant only one of the $P$ stages is busy.

The standard fix from training — Megatron again — is to keep several microbatches in flight so a stage always has something to chew on: as soon as stage 0 hands batch 1 forward, it picks up batch 2, and the stages fill up like a factory line. vLLM applies the same idea at serving time by keeping several batches in flight through the pipeline at once — a batch queue sized to the PP degree, since “PP requires PP-size concurrent batches to fill the pipeline” (max_concurrent_batches in vllm/config/vllm.py, drained by step_with_batch_queue in vllm/v1/engine/core.py). The sequence below shows two stages first running one batch (idle gaps everywhere) and then overlapping two batches so each stage stays busy.

sequenceDiagram
    participant S0 as Stage 0 layers 0..N
    participant S1 as Stage 1 layers N..2N
    Note over S0,S1: one batch in flight, note the bubble
    S0->>S0: compute batch A
    S0->>S1: send activations of A
    Note over S0: S0 idle, the bubble
    S1->>S1: compute batch A
    Note over S0,S1: two batches in flight, bubble shrinks
    S0->>S0: compute batch A
    S0->>S1: send A
    S1->>S1: compute A
    S0->>S0: compute batch B
    S0->>S1: send B
    S1->>S1: compute B

But it never fully hides the bubble at the fill and drain edges — at the very start only stage 0 has work, and at the very end only the last stage does. The fill-and-drain idle fraction of a $P$-stage pipeline fed $m$ microbatches is $(P-1)/(m+P-1)$: with a single batch in flight ($m=1$) it collapses to $(P-1)/P$, which is the “only one of the $P$ stages is busy” worst case, and it shrinks toward zero only as $m$ grows well past $P$. The curve below shows that bubble fraction against pipeline depth for a few microbatch counts.

Illustrative: computed from the ideal $(P-1)/(m+P-1)$ bubble formula, which ignores per-stage compute imbalance, communication time, and decode’s one-token-per-step starvation; real serving bubbles are larger.
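The formula can be checked against a direct simulation. This sketch (plain Python, hypothetical function names) assumes the same idealized fill-and-drain schedule the formula does: unit compute per stage, zero communication time, and stage $s$ running microbatch $j$ in time slot $s + j$.

```python
from fractions import Fraction


def bubble_fraction(p, m):
    """Idle fraction of a p-stage pipeline fed m microbatches: (p-1)/(m+p-1)."""
    return Fraction(p - 1, m + p - 1)


def simulated_bubble(p, m):
    """Count idle (stage, slot) cells in an idealized fill-and-drain schedule."""
    total_slots = m + p - 1                     # the last stage finishes the last microbatch here
    busy = {(s, s + j) for s in range(p) for j in range(m)}
    idle = p * total_slots - len(busy)
    return Fraction(idle, p * total_slots)


for p in (2, 4, 8):
    for m in (1, 4, 16):
        assert simulated_bubble(p, m) == bubble_fraction(p, m)
    print(p, [round(float(bubble_fraction(p, m)), 3) for m in (1, 4, 16)])
```

Exact Fraction arithmetic keeps the comparison free of rounding; the printed rows show the bubble shrinking as m grows past p.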

Decode makes this worse: a decode step is a single token per sequence, so there is precious little work to overlap. PP buys you capacity for weights that would not otherwise fit, at the price of latency you can only partially reclaim. TP and PP are routinely combined (TP within a box, PP across boxes), which is exactly what the ExternalDP x DP x PP x TP layout is built to express.

Expert parallelism: split by expert, pay an all-to-all

The third axis only exists for mixture-of-experts (MoE) models, and it is where the chapter’s real subject begins. In a dense transformer, every token passes through the same single MLP. In an MoE layer that one MLP is replaced by many independent expert MLPs — separate sets of weights, often dozens or hundreds of them — plus a small router network that looks at each token and picks a few experts (the top-k, often just one or two) to actually run for that token. Every other expert is skipped for that token. Switch Transformers (arXiv:2101.03961) showed you can scale parameter count enormously this way while keeping per-token compute roughly fixed, because each token only ever touches a handful of experts: you can add experts to grow the model’s capacity without making any single token’s forward pass more expensive. The natural way to serve such a model is expert parallelism: put different experts on different GPUs. vLLM builds the EP group only when the model actually has experts —

# vllm/distributed/parallel_state.py
# Don't create EP group for dense models.
if config.model_config is None or config.model_config.is_moe:

Source: vllm/distributed/parallel_state.py

— and the placement of experts onto ranks is computed by determine_expert_map, which by default spreads them linearly and returns a global-to-local map with -1 marking experts this rank does not own (so a rank can quickly tell “not mine, this token’s expert lives elsewhere”). The communication pattern is the giveaway. A token’s chosen experts can live on any GPU, not the one the token currently sits on, so the layer must route every token to wherever its chosen experts are, run them, then route the results back. The forward trip is the dispatch: each rank sends every token to the rank that owns the expert that token picked. After the experts compute, the combine sends each result back to the rank the token came from. Each of those is an all-to-all collective — a pattern where, in the general case, every rank sends a (possibly different) chunk of data to every other rank simultaneously. The diagram below traces one token’s round trip through dispatch and combine.

flowchart LR
    T["token on GPU 0, router picks expert 11"]
    T -->|dispatch all-to-all| E["GPU 1 runs expert 11"]
    E -->|combine all-to-all| B["result returns to GPU 0"]
    B --> N["token continues to next layer on GPU 0"]

vLLM has a whole family of backends for these two collectives (vllm/distributed/device_communicators/all2all.py, with the choice exposed as all2all_backend in vllm/config/parallel.py, defaulting to an allgather/reduce-scatter scheme). All-to-all is the most expensive collective shape there is, because every rank talks to every other rank, but the saving grace is that each token is sent only to the ranks hosting the k experts it selected, so the volume of data on the wire stays small even though the fan-out is total.
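A minimal sketch of the bookkeeping around dispatch, assuming the simplest linear placement in which the expert count divides evenly by the EP size. vLLM’s determine_expert_map handles more cases than this; linear_expert_map and dispatch_plan are hypothetical names. The map uses -1 for experts the rank does not own, and the dispatch plan groups each (token, expert) pair by the rank that must run it, which is what the dispatch all-to-all then ships; combine reverses the same grouping.

```python
from collections import defaultdict


def linear_expert_map(num_experts, ep_size, ep_rank):
    """Global expert id -> local slot on ep_rank, or -1 if another rank owns it.
    Assumes num_experts % ep_size == 0 (contiguous, linear placement)."""
    per_rank = num_experts // ep_size
    first = ep_rank * per_rank
    return [e - first if first <= e < first + per_rank else -1 for e in range(num_experts)]


def dispatch_plan(token_experts, num_experts, ep_size):
    """Group (token, expert) pairs by the rank that owns the expert."""
    per_rank = num_experts // ep_size
    plan = defaultdict(list)
    for token, experts in enumerate(token_experts):
        for expert in experts:
            plan[expert // per_rank].append((token, expert))
    return dict(plan)


# 8 experts on 4 ranks, top-2 routing for three tokens.
print(linear_expert_map(8, 4, 1))   # [-1, -1, 0, 1, -1, -1, -1, -1]
print(dispatch_plan([[0, 5], [3, 7], [2, 4]], 8, 4))
```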

And here is the catch that the rest of the chapter is about. An all-to-all is a synchronizing step: it cannot finish until every rank has both sent and received its share, so the whole collective is gated by the slowest participant. And the slowest participant is whichever rank has the most expert work to do. If the router sends 40% of all tokens to experts that happen to live on GPU 3, then GPU 3 is doing 40% of the expert compute while the other GPUs idle, and every all-to-all stalls until GPU 3 finishes. Expert parallelism is only as fast as its most overloaded expert. Routers, trained on data that does not look like your production traffic, are reliably not uniform — a few experts become popular, others go nearly cold — and so a freshly deployed MoE model often runs at a fraction of its theoretical throughput purely from load skew. The DeepSeek-V3 technical report (arXiv:2412.19437) documents this at scale and describes the production answer: replicate the hot experts and actively rebalance which GPU hosts which. That answer is the EPLB subsystem.

Expert-parallel load balancing: rebalancing under live traffic

The vocabulary first, because EPLB’s whole design rests on it. The docstring of eplb_state.py lays it out:

# vllm/distributed/eplb/eplb_state.py
# - **Logical Expert**: An expert that is part of the model's logical structure.
# - **Redundant Expert**: To achieve load balancing, for some popular logical
#   experts, we create additional copies of the expert weights.
# - **Physical Expert**: An expert that is instantiated on a specific device.

The distinction is the whole trick, so make it concrete. A logical expert is one of the experts the model architecture defines: there are a fixed number of them, say 256, and the router only ever picks logical experts. A physical expert is an actual instantiated copy of an expert’s weights sitting on a specific GPU. Normally there is one physical copy per logical expert. EPLB breaks that one-to-one assumption: it deploys more physical slots than there are logical experts (the extras are the redundant experts), so DeepSeek-R1, whose MoE layers each define 256 routed experts, becomes 288 physical experts per MoE layer when given 32 redundant experts. The freedom this buys is that a single hot logical expert can be backed by several physical copies on several GPUs. The router still picks the logical expert; the runtime then spreads the tokens that chose it across its physical copies, so its load is divided instead of piling onto one device. The diagram below shows a hot expert (logical 11) given two physical copies while cold experts keep one.

flowchart TD
    L11["logical expert 11 (hot, 40% of tokens)"]
    L42["logical expert 42 (cold)"]
    L11 --> P0["physical slot on GPU 0"]
    L11 --> P1["physical slot on GPU 1"]
    L42 --> P2["physical slot on GPU 2"]
    P0 --> G0["GPU 0 now carries ~20%"]
    P1 --> G1["GPU 1 now carries ~20%"]
    P2 --> G2["GPU 2 carries its small share"]

The job of EPLB is to decide, from observed traffic, the mapping from physical slots to logical experts, and to keep that mapping current as traffic drifts.

EPLB runs as a loop with three phases: observe the load every step, periodically decide a better mapping, then apply it by moving weights. For observation, the engine records how many tokens each physical expert served into a sliding window (a count over the most recent N steps, so old traffic ages out) and periodically logs how lopsided things are: the step method computes a balancedness ratio, the mean load over the max load across ranks,

$$\text{balancedness} = \frac{\text{avg tokens}}{\text{max tokens}}$$

which is exactly the “busiest rank sets the pace” intuition turned into a number you can alert on:

# vllm/distributed/eplb/eplb_state.py
balancedness = avg_tokens / max_tokens if max_tokens > 0 else 0.0

Source: vllm/distributed/eplb/eplb_state.py

A balancedness of 1.0 is perfect; 0.4 means your worst GPU is carrying more than double its fair share and your effective throughput is roughly that fraction of peak. The bars below make the earlier “40% of tokens land on GPU 3” example concrete on an 8-GPU cluster: skewed, the busiest GPU holds 40% against a 12.5% fair share, so balancedness is $12.5/40 \approx 0.31$ and the all-to-all runs at roughly a third of peak; after EPLB replicates the hot expert across more physical slots, the per-GPU load flattens back toward that 12.5% line and balancedness climbs near 1.0.

Illustrative: an 8-GPU sketch of the chapter’s 40%-on-one-GPU example; real per-rank loads come from the observed token counts EPLB records, and the post-rebalance distribution depends on how many redundant slots the hot expert receives.
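A small sketch of the observation phase, under stated assumptions: per-step token counts arrive per rank, a deque with maxlen=window_size plays the role of the sliding window, and LoadWindow is a hypothetical name rather than vLLM’s class. The balancedness method mirrors the expression quoted above, and the numbers reproduce the 8-GPU example.

```python
from collections import deque


class LoadWindow:
    """Sliding window of per-rank token counts over the most recent window_size steps."""

    def __init__(self, window_size, num_ranks):
        self.steps = deque(maxlen=window_size)  # the oldest step falls out automatically
        self.num_ranks = num_ranks

    def record(self, tokens_per_rank):
        assert len(tokens_per_rank) == self.num_ranks
        self.steps.append(list(tokens_per_rank))

    def totals(self):
        return [sum(step[r] for step in self.steps) for r in range(self.num_ranks)]

    def balancedness(self):
        tokens = self.totals()
        max_tokens = max(tokens)
        avg_tokens = sum(tokens) / len(tokens)
        return avg_tokens / max_tokens if max_tokens > 0 else 0.0


# 8 GPUs, 1400 tokens per step: GPU 3 gets 40% (560), the other seven split 840 evenly.
skewed = [120, 120, 120, 560, 120, 120, 120, 120]
window = LoadWindow(window_size=3, num_ranks=8)
for _ in range(3):
    window.record(skewed)
print(window.balancedness())  # 0.3125, i.e. 12.5 / 40
for _ in range(3):            # traffic flattens; the skewed steps age out of the window
    window.record([175] * 8)
print(window.balancedness())  # 1.0
```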

The window size and how often to act on it are tunable (window_size defaults to 1000 steps and step_interval to 3000 in EPLBConfig, in vllm/config/parallel.py), so EPLB reacts to sustained skew, not to momentary noise. The flowchart below traces the full loop, including the asynchronous apply path described at the end of this section.

flowchart TD
    A["every step: record per-expert token counts in sliding window"]
    A --> B{"step_interval reached?"}
    B -->|no| A
    B -->|yes| C["run placement: replicate hot experts, pack onto GPUs"]
    C --> D["stage new expert weights into off-critical-path buffer (inference keeps running on old weights)"]
    D --> E{"all ranks confirm transfer done?"}
    E -->|no, keep serving and wait| E
    E -->|yes| F["atomic swap: move_from_buffer to live weights, all ranks flip together"]
    F --> A

When the rearrangement step fires, the recorded loads feed a placement algorithm in vllm/distributed/eplb/policy/default.py, adapted from DeepSeek’s open-source EPLB. It is a two-part packing problem: first decide how many copies each expert gets, then decide which GPU each copy lives on. The first part, replication, hands the redundant physical slots to the heaviest logical experts, greedily. It repeatedly looks at the load each expert would carry if its current copies split its tokens evenly — weight / logcnt, the expert’s total load divided by its replica count — and gives the next spare copy to whichever expert is worst off by that measure, then updates that expert’s replica count and repeats. An expert pulling 40% of traffic with one copy looks far worse than one pulling 5%, so it gets the first extra copy; after that its per-copy load halves and some other expert may become the worst, so the next copy may go elsewhere —

# vllm/distributed/eplb/policy/default.py
for i in range(num_log, num_phy):
    redundant_indices = np.argmax(weight / logcnt, axis=-1)
    phy2log[:, i] = redundant_indices
    logcnt[arangen, redundant_indices] += 1

— and second, pack those physical experts onto GPUs so the per-GPU totals come out even. This is bin packing: each GPU is a bin with a load budget, each physical expert is an item weighing its share of the traffic, and balanced_packing places items so no bin ends up far heavier than the others (lightest-bin-first). The policy is hierarchical: it first packs expert groups onto nodes, then replicates within a node, then packs onto GPUs within a node, because intra-node NVLink is cheap and cross-node traffic is not — the same topology awareness that drove the rank layout at the top of the chapter. Keeping a hot expert’s copies on the same node means its share of every all-to-all stays on the fast intra-node links. One detail reveals how much this is built for live reconfiguration rather than restart: preserve_intragpu_slots reorders the new mapping so that any expert staying on a given GPU keeps its old slot, which means its weights never have to move:

# vllm/distributed/eplb/policy/default.py
# Reorder the new mapping per GPU so that experts that remain on the same GPU
# keep their previous slot positions when possible.

Source: vllm/distributed/eplb/policy/default.py
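The two-part policy can be sketched in plain Python. This is a simplified, single-level version under stated assumptions (no node hierarchy, hypothetical names replicate and balanced_pack) that follows the greedy rules described above rather than reproducing vLLM’s default.py: each spare slot goes to the expert with the highest per-copy load, and each physical expert, heaviest first, goes to the lightest GPU that still has a free slot.

```python
def replicate(weights, num_physical):
    """Greedily hand spare physical slots to the logical expert with the highest per-copy load.
    Returns (phy2log, copies): physical slot -> logical expert, and copies per logical expert."""
    num_logical = len(weights)
    copies = [1] * num_logical
    phy2log = list(range(num_logical))
    for _ in range(num_logical, num_physical):
        worst = max(range(num_logical), key=lambda e: weights[e] / copies[e])
        phy2log.append(worst)
        copies[worst] += 1
    return phy2log, copies


def balanced_pack(items, num_gpus):
    """Place physical experts (items = load per copy) so every GPU gets the same slot count and
    per-GPU load stays even: heaviest item first, onto the lightest GPU with a free slot."""
    assert len(items) % num_gpus == 0
    per_gpu = len(items) // num_gpus
    loads, slots = [0.0] * num_gpus, [[] for _ in range(num_gpus)]
    for idx in sorted(range(len(items)), key=lambda i: -items[i]):
        gpu = min((g for g in range(num_gpus) if len(slots[g]) < per_gpu), key=lambda g: loads[g])
        slots[gpu].append(idx)
        loads[gpu] += items[idx]
    return slots, loads


weights = [40, 10, 10, 10, 10, 10, 10]         # logical expert 0 is hot
phy2log, copies = replicate(weights, num_physical=8)
items = [weights[e] / copies[e] for e in phy2log]
slots, loads = balanced_pack(items, num_gpus=4)
print(phy2log)   # [0, 1, 2, 3, 4, 5, 6, 0]: the spare slot went to expert 0
print(loads)     # [30.0, 30.0, 20.0, 20.0]
```

The hot expert’s two copies land on different GPUs, and the worst GPU carries 30 against a fair share of 25. The sketch omits the node-level pass and the slot-preserving reorder quoted just above.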

That matters because actually applying a new mapping means shipping expert weights between GPUs while the model is serving requests. rearrange_expert_weights_inplace in rebalance_execute.py walks the layers and, for each, copies the experts that need to move into a pre-allocated buffer and exchanges them among the EP ranks. That exchange does not share a communicator with inference: EPLB gets its own process group, deliberately separated from the forward-pass collectives:

# vllm/distributed/parallel_state.py
# This is a separate process group to isolate EPLB communications
# from MoE forward pass collectives and prevent deadlocks when
# using torch.distributed in execution with torch.distributed in EPLB.

Source: vllm/distributed/parallel_state.py

The most production-critical refinement is that this can run asynchronously. A synchronous rearrangement would freeze every GPU mid-traffic while gigabytes of expert weights shuffle around, a latency spike no SLO survives. The async path (vllm/distributed/eplb/async_worker.py, driven by the is_async branches in step) copies the new weights into a separate buffer while normal inference keeps running on the old weights, then performs a fast in-place swap via move_from_buffer only once every transfer has finished. The swap must be all-or-nothing across the whole cluster, which is why the step method checks _all_ranks_result_ready before committing: if some ranks switched to the new physical-to-logical mapping while others were still on the old one, the dispatch all-to-all would send tokens to a GPU that no longer hosts the expert they expect, corrupting the routing. Waiting for every rank to confirm guarantees that all of them flip on the same step. The same machinery generalizes to changing the number of GPUs entirely: vllm/distributed/elastic_ep/ reuses the rank-remapping logic (the rank_mapping argument threaded through rearrange_expert_weights_inplace) to scale an EP deployment up or down, while vllm/model_executor/layers/fused_moe/eep_reconfigure.py rebuilds the fused-MoE expert kernels for the new EP size. Together they redistribute experts onto a changed device set without a full restart.
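The all-or-nothing commit reduces to a small protocol. The sketch below is a single-process simulation, not vLLM code: RankState and try_commit are hypothetical names standing in for the staging buffer, the _all_ranks_result_ready check, and move_from_buffer. Serving keeps reading live_mapping until every rank reports its transfer complete, and then all ranks flip together.

```python
class RankState:
    """One EP rank: a live physical->logical mapping plus a staging buffer."""

    def __init__(self, mapping):
        self.live_mapping = list(mapping)   # what the dispatch all-to-all routes by
        self.staged_mapping = None          # filled in off the critical path
        self.transfer_done = False

    def stage(self, new_mapping):
        self.staged_mapping = list(new_mapping)
        self.transfer_done = False          # weights are still in flight

    def finish_transfer(self):
        self.transfer_done = True


def try_commit(ranks):
    """Flip every rank to its staged mapping, or none of them."""
    if not all(r.staged_mapping is not None and r.transfer_done for r in ranks):
        return False                        # keep serving on the old mapping
    for r in ranks:
        r.live_mapping, r.staged_mapping, r.transfer_done = r.staged_mapping, None, False
    return True


ranks = [RankState([0, 1, 2, 3]) for _ in range(2)]
for r in ranks:
    r.stage([0, 1, 0, 2])
ranks[0].finish_transfer()
print(try_commit(ranks), [r.live_mapping for r in ranks])  # False; both ranks still on the old mapping
ranks[1].finish_transfer()
print(try_commit(ranks), [r.live_mapping for r in ranks])  # True; both ranks on the new mapping
```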

What is still hard

None of this is settled engineering. Picking parallel degrees (how much TP versus PP versus EP for a given model, cluster, and SLO) is still mostly empirical; the cost models are crude and the search space is large. EPLB rebalances on past load, so a sharp traffic shift outruns it until the next window closes, and choosing window_size/step_interval trades responsiveness against thrashing the network with weight shuffles. The replication count (num_redundant_experts) spends HBM you could have given to KV cache, re-opening the Chapter 4 tradeoff at cluster scale: every redundant copy occupies memory that could otherwise hold cached prefix blocks. And all of this multiplies the failure surface: a single dead GPU now takes down a TP shard, a pipeline stage, or an expert host, each with a different blast radius. Fault tolerance for inference at this scale is genuinely unsolved; most systems today simply restart the affected replica.

What we have not split is the KV cache itself. We sharded weights every way there is, but a request’s KV cache still lives wholly within one replica, and we have been assuming it fits. The next chapter breaks that assumption — when the prefix you want has been evicted from GPU memory, you spill cold KV blocks to CPU, disk, or an external store through a connector abstraction. That same connector, stretched across a network, is what finally lets prefill and decode run on separate GPU pools two chapters from now.

Further reading

  • Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism — arXiv:1909.08053 — the origin of the column/row-parallel pairing and the one-all-reduce-per-block trick vLLM’s linear.py implements; read it for why TP communication is structured the way it is.
  • Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity — arXiv:2101.03961 — simplifies mixture-of-experts routing to a single expert per token (top-1) and scales it with expert parallelism; read it for the intuition that parameter count and per-token compute can be decoupled.
  • DeepSeek-V3 Technical Report — arXiv:2412.19437 — documents large-scale MoE serving and the redundant-expert / load-balancing practice that vLLM’s EPLB directly adapts; read it for what expert skew looks like at production scale.

When the cache overflows: KV offloading and external KV stores

In Chapter 7 the GPU’s free KV pool turned into a content-addressed cache: identical prefixes get hashed, deduplicated, and reused, and prefix reuse became the single largest TTFT lever in real traffic. But that cache has a hard ceiling. The block pool lives in HBM, and HBM is the most expensive, scarcest resource in the box. When a popular system prompt has been pushed out by other traffic, the next request that wants it pays full prefill again, even though the exact KV bytes existed a few seconds ago. Chapter 15 then spread the model across many GPUs, which adds aggregate HBM but does nothing to change the eviction problem on any single device. The cache still overflows; it just overflows on more GPUs at once.

This chapter is about giving the prefix cache somewhere to fall back to. CPU DRAM (the host’s main memory, reached over the PCIe bus) is roughly an order of magnitude cheaper per byte than HBM (the GPU’s on-package high-bandwidth memory, where the block pool lives), and a host has a lot more of it. Local NVMe (a solid-state disk on the same machine) is cheaper still, and a networked store (memory or disk on another machine, reached over the network) cheaper again. The KV blocks you evicted from the GPU are not garbage; they are cold cache lines. If you can spill them to a cheaper tier and pull them back faster than you could recompute them, you have extended the effective prefix cache far beyond device capacity. The catch is the “faster than recompute” part, and the abstraction vLLM uses to manage the spilling is, deliberately, the same connector API that the next chapter will stretch across a network to move KV between separate machines.

These tiers form a classic memory hierarchy: each step down is cheaper and larger but slower to reach. The diagram below shows where KV blocks can live and roughly how fast each tier answers a request for them. Think of HBM as the working set, DRAM as the overflow, and the slower tiers as deep archive that only pays off for prefixes long and popular enough to amortize the trip.

flowchart TD
    GPU["GPU HBM block pool: fastest, smallest, most expensive"]
    DRAM["CPU DRAM (over PCIe): ~10x cheaper, much larger"]
    NVME["Local NVMe disk: cheaper still, larger"]
    NET["Networked / external store: cheapest, largest, shared across replicas"]
    GPU -->|"evict / spill (store)"| DRAM
    DRAM -->|"fetch on prefix hit (load)"| GPU
    DRAM -->|"spill"| NVME
    NVME -->|"load"| DRAM
    NVME -->|"spill"| NET
    NET -->|"load"| NVME

That diagram shows where blocks can live and how they move, but not how far apart the tiers really are. The cost of the gap is the whole reason offloading is a bet rather than a free win: each step down is roughly an order of magnitude slower to reach. The chart below shows representative round-trip latencies to fetch one KV block from each tier, on a log scale so the roughly-10x-per-step spacing is visible at a glance.

Illustrative orders of magnitude: the shape (about 10x slower per step down) is the point, not the exact figures, which vary with bus generation, block size, and contention.

The economics: when is a load cheaper than a recompute?

Decode is memory-bound and prefill is compute-bound; that asymmetry from Chapter 3 is exactly what makes offloading viable. Recomputing a prefix means running prefill over every cached token: a compute-bound pass whose cost grows with prefix length. Loading the same prefix from CPU means a PCIe transfer whose cost is the KV byte count divided by bus bandwidth. For a prefix of $n$ tokens the KV payload follows directly from the model shape:

$$\text{bytes} = n \cdot 2 \cdot L \cdot H \cdot d_{\text{head}} \cdot b$$

where $L$ is the number of layers, $H$ the number of KV heads, $d_{\text{head}}$ the per-head dimension, $b$ the bytes per element, and the factor $2$ counts the separate key and value tensors. The load time is then $t_{\text{load}} = \text{bytes} / \text{BW}_{\text{bus}}$, while the recompute time $t_{\text{recompute}}$ is the cost of a prefill over the same $n$ tokens. The decision is therefore a direct comparison of two times, both of which rise with prefix length: recompute on the compute engine, load on the bus. Because moving already-computed bytes is far cheaper per token than re-deriving them, the two curves cross: below some prefix length recompute wins, above it the load wins, and for a long shared prefix it wins by a wide margin. For a short prefix the load can lose, because the fixed overheads of staging a transfer (setting up the transfer descriptor that tells the copy engine which bytes to move, synchronizing CUDA streams, and copying a partial first block) swamp the few hundred microseconds prefill would have taken anyway.
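To put numbers on the formula, here is a sketch that evaluates it for an assumed Llama-3-8B-like shape (32 layers, 8 KV heads, head dimension 128, 2-byte fp16/bf16 elements) and an assumed 25 GB/s of effective host-to-device bandwidth. Both figures are illustrative assumptions, not measurements; substitute your own model config and measured bus bandwidth.

```python
def kv_bytes(n_tokens, num_layers, num_kv_heads, head_dim, bytes_per_elem):
    """bytes = n * 2 * L * H * d_head * b; the 2 counts keys and values."""
    return n_tokens * 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem


def load_seconds(n_bytes, bus_bytes_per_s):
    """t_load = bytes / BW_bus (ignores the fixed staging overhead)."""
    return n_bytes / bus_bytes_per_s


SHAPE = dict(num_layers=32, num_kv_heads=8, head_dim=128, bytes_per_elem=2)  # assumed shape
BUS = 25e9                                                                   # assumed bytes/s

per_token = kv_bytes(1, **SHAPE)
prefix = kv_bytes(4096, **SHAPE)
print(per_token)                  # 131072 bytes, i.e. 128 KiB per token
print(prefix / 2**20)             # 512.0 MiB for a 4096-token prefix
print(load_seconds(prefix, BUS))  # roughly 0.021 s over the assumed bus
```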

The crossover below makes this concrete. Recompute cost is essentially linear in prefix length with no fixed startup, so it is a straight line through the origin. Load cost is also linear in length but with a cheaper per-token slope (you are moving bytes, not re-deriving them) plus a fixed staging overhead (transfer-descriptor setup, stream sync, partial first block), so it starts high and climbs slowly. Below the crossing point the fixed overhead makes the load lose; above it the cheaper slope makes the load win by a widening margin.

Illustrative: recompute = 8 us/token, load = 300 us fixed + 5 us/token. Real slopes and the fixed overhead depend on model shape, GPU, and bus bandwidth, but the qualitative crossover (short prefix favors recompute, long prefix favors load) is the durable point.
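The crossover follows from setting the two linear costs equal: with recompute slope $r$, load slope $\ell$, and fixed load overhead $F$, $r n = F + \ell n$ gives $n^{*} = F / (r - \ell)$, which exists only when loading is cheaper per token ($\ell < r$). A short sketch using the caption’s illustrative figures (8 us/token recompute, 300 us fixed plus 5 us/token load; not measurements), with hypothetical function names:

```python
def crossover_tokens(fixed_us, recompute_us_per_tok, load_us_per_tok):
    """Prefix length where recompute (r*n) and load (F + l*n) cost the same."""
    if load_us_per_tok >= recompute_us_per_tok:
        return None  # loading never catches up: recompute always wins
    return fixed_us / (recompute_us_per_tok - load_us_per_tok)


def prefer_load(n_tokens, fixed_us=300.0, recompute_us_per_tok=8.0, load_us_per_tok=5.0):
    """True when fetching n_tokens of KV beats recomputing them."""
    recompute = recompute_us_per_tok * n_tokens
    load = fixed_us + load_us_per_tok * n_tokens
    return load < recompute


print(crossover_tokens(300.0, 8.0, 5.0))  # 100.0 tokens
print(prefer_load(50), prefer_load(200))  # False True
```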

So offloading is not a free win; it is a bet that the prefix is both long enough and reused often enough to amortize the round trip. vLLM encodes this skepticism directly. The CPU spec defaults to offloading prompt blocks only, skipping decode-phase KV (the KV produced one token at a time during generation, which is unlikely to be a shared prefix for any future request), and it can require a block to be requested more than once before it earns a slot in the cache. The store threshold and the prompt-only default both exist because indiscriminately offloading everything just burns PCIe bandwidth on blocks no one will ask for again.
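The two admission rules can be sketched as a small filter. This is an illustration under assumptions, not vLLM’s CPU spec: StoreAdmission, min_requests, and prompt_only are hypothetical names, and a real implementation would bound the counter’s memory.

```python
from collections import Counter


class StoreAdmission:
    """Decide whether an evicted GPU block is worth offloading."""

    def __init__(self, min_requests=2, prompt_only=True):
        self.min_requests = min_requests  # how many times a block must be seen first
        self.prompt_only = prompt_only    # skip decode-phase KV by default
        self.seen = Counter()

    def should_store(self, key, is_prompt_block):
        if self.prompt_only and not is_prompt_block:
            return False                  # decode KV is rarely a future shared prefix
        self.seen[key] += 1
        return self.seen[key] >= self.min_requests


adm = StoreAdmission()
print(adm.should_store("sys-prompt-blk0", True))  # False: first sighting
print(adm.should_store("sys-prompt-blk0", True))  # True: requested twice
print(adm.should_store("decode-blk7", False))     # False: decode-phase block
```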

The decision flow below traces how the engine resolves a single waiting request against this tiered cache. It checks the fastest tier first and falls through to the next only when a tier misses; prefill then runs over whatever tokens no tier could supply.

flowchart TD
    REQ["new request: hash prompt into block keys"] --> GPUHIT{"blocks already in GPU HBM?"}
    GPUHIT -->|"yes"| REUSE["reuse in place, no transfer"]
    GPUHIT -->|"no"| OFFHIT{"blocks found in offload tier? (lookup)"}
    OFFHIT -->|"no"| RECOMP["recompute: run prefill over the missing tokens"]
    OFFHIT -->|"yes"| WORTH{"long and reused enough to beat recompute?"}
    WORTH -->|"no"| RECOMP
    WORTH -->|"yes"| LOAD["prepare_load: pin blocks, transfer DRAM to HBM"]
    REUSE --> SCHED["schedule remaining tokens for prefill / decode"]
    LOAD --> SCHED
    RECOMP --> SCHED
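The flow above can be sketched as a function that splits a request’s block keys into three runs. It assumes, as parent-chained hashing implies, that matches are prefixes: the first miss ends a tier’s run, because every later block’s hash depends on the missing one. plan_prefix and the worth_loading callback are hypothetical.

```python
def plan_prefix(block_keys, gpu_blocks, offload_blocks, worth_loading):
    """Return (reused_in_hbm, loaded_from_offload, recomputed) block counts."""
    gpu_hits = 0
    for key in block_keys:                    # longest prefix already in HBM
        if key not in gpu_blocks:
            break
        gpu_hits += 1
    offload_hits = 0
    for key in block_keys[gpu_hits:]:         # continue the prefix in the offload tier
        if key not in offload_blocks:
            break
        offload_hits += 1
    if offload_hits and not worth_loading(offload_hits):
        offload_hits = 0                      # too short to beat recompute
    return gpu_hits, offload_hits, len(block_keys) - gpu_hits - offload_hits


keys = ["k0", "k1", "k2", "k3", "k4"]
print(plan_prefix(keys, {"k0", "k1"}, {"k2", "k3"}, lambda n: n >= 2))  # (2, 2, 1)
print(plan_prefix(keys, {"k0", "k1"}, {"k2", "k3"}, lambda n: n >= 3))  # (2, 0, 3)
```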

The OffloadingManager: a second-tier block allocator in the scheduler

The scheduler-side brain of offloading is the OffloadingManager, and the docstrings on its primitives lay out a vocabulary that should feel familiar from the GPU block pool, only one level down the hierarchy.

# materials/.../vllm/v1/kv_offload/base.py
class OffloadingManager(ABC):
    @abstractmethod
    def lookup(self, key: OffloadKey, req_context: ReqContext) -> bool | None:
        """Checks whether a single block is offloaded and ready to be read."""
    @abstractmethod
    def prepare_load(self, keys, req_context) -> LoadStoreSpec:
        """Prepare the given blocks to be read.
        The given blocks will be protected from eviction..."""
    @abstractmethod
    def prepare_store(self, keys, req_context) -> PrepareStoreOutput | None:
        """Prepare the given blocks to be written."""

Source: vllm/v1/kv_offload/base.py

Read these three methods as the verbs of a second allocator. lookup asks whether a cold block is present in the offload tier (the answer feeds the “blocks found in offload tier?” branch in the diagram above). prepare_load pins blocks so they cannot be evicted mid-transfer (pinning means marking them in-use so the eviction policy skips them) and hands back a LoadStoreSpec, a small record describing where the bytes live so the worker knows what to copy. prepare_store reserves space for blocks on their way out of the GPU, and crucially returns the list of other blocks it had to evict to make room, so the caller learns exactly what was displaced. This is an allocator with its own eviction policy, sitting behind the GPU’s allocator, addressed by the same content hashes Chapter 7 computed for prefix caching. The key is just the block hash packed with its KV-cache-group index (the index identifying which group of attention layers a block belongs to, since a model can have several such groups):

# materials/.../vllm/v1/kv_offload/base.py
def make_offload_key(block_hash: bytes, group_idx: int) -> OffloadKey:
    """Pack a block hash and group index into an `OffloadKey`."""
    return OffloadKey(block_hash + group_idx.to_bytes(4, "big", signed=False))

Source: vllm/v1/kv_offload/base.py

Because the offload key derives from the same parent-chained block hash the GPU prefix cache uses, a block evicted from HBM and a block resident in DRAM share an identity. A later request hashing its prompt will produce the same keys, and a lookup will find them. The two caches are tiers of one logical cache, not two separate caches that happen to hold similar data.
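A sketch of why the identities line up, under an assumption: each block’s hash here is SHA-256 over the parent hash plus the block’s token ids, a stand-in for vLLM’s actual block-hash function, which differs in detail. chain_block_hashes is hypothetical; make_offload_key mirrors the packing quoted above but returns plain bytes instead of an OffloadKey.

```python
import hashlib


def chain_block_hashes(token_ids, block_size):
    """Parent-chained hashes: block i's hash covers every token before it."""
    hashes, parent = [], b""
    full = len(token_ids) - len(token_ids) % block_size   # only full blocks are cached
    for start in range(0, full, block_size):
        block = token_ids[start:start + block_size]
        payload = parent + b",".join(str(t).encode() for t in block)
        parent = hashlib.sha256(payload).digest()
        hashes.append(parent)
    return hashes


def make_offload_key(block_hash, group_idx):
    """Block hash followed by a 4-byte big-endian KV-cache-group index."""
    return block_hash + group_idx.to_bytes(4, "big", signed=False)


a = chain_block_hashes([1, 2, 3, 4, 5, 6, 7, 8], block_size=4)
b = chain_block_hashes([1, 2, 3, 4, 9, 9, 9, 9], block_size=4)
print(a[0] == b[0], a[1] == b[1])                              # True False
print(make_offload_key(a[0], 0) != make_offload_key(a[0], 1))  # True: groups stay distinct
```

Two prompts that share their first block share its key, so whichever tier holds that block serves both; once they diverge, every later key differs.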

The CPU manager is, under the hood, exactly a paged allocator with a pluggable eviction policy:

# materials/.../vllm/v1/kv_offload/cpu/manager.py
_CACHE_POLICIES: dict[str, type[CachePolicy]] = {
    "lru": LRUCachePolicy,
    "arc": ARCCachePolicy,
}

Source: vllm/v1/kv_offload/cpu/manager.py

LRU is the default; ARC (adaptive replacement) is offered for workloads where pure recency misbehaves. Either way the manager owns ref-counting, a free list, and event emission, and delegates only the “which block dies next” decision to the policy. If you have built a buffer cache, none of this is new; what is new is that the cache lines are KV blocks keyed by prompt content.
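A toy version of such an allocator, for intuition only. TinyLRUOffloadManager is hypothetical and simplified: the real manager delegates the choice of victim to a pluggable policy, while this sketch hard-codes LRU with an OrderedDict. It keeps the same verbs, with pin counts protecting blocks mid-load and prepare_store returning what it evicted (or None when everything evictable is pinned).

```python
from collections import OrderedDict


class TinyLRUOffloadManager:
    """Toy second-tier allocator: fixed slot count, LRU eviction, pin counts."""

    def __init__(self, num_slots):
        self.num_slots = num_slots
        self.blocks = OrderedDict()  # key -> pin count, least recently used first

    def lookup(self, key):
        return key in self.blocks

    def prepare_load(self, keys):
        """Call only for keys that lookup() found."""
        for key in keys:
            self.blocks[key] += 1            # pin: eviction must skip it mid-transfer
            self.blocks.move_to_end(key)     # a hit counts as a recent use
        return list(keys)

    def complete_load(self, keys):
        for key in keys:
            self.blocks[key] -= 1            # unpin once the copy to HBM is done

    def prepare_store(self, keys):
        """Reserve slots for keys; return the evicted keys, or None if there is no room."""
        incoming = set(keys)
        new = [k for k in keys if k not in self.blocks]
        need = len(self.blocks) + len(new) - self.num_slots
        evicted = []
        for key, pins in self.blocks.items():  # walk from least recently used
            if len(evicted) >= need:
                break
            if pins == 0 and key not in incoming:
                evicted.append(key)
        if len(evicted) < need:
            return None                      # every candidate is pinned
        for key in evicted:
            del self.blocks[key]
        for key in new:
            self.blocks[key] = 0
        return evicted


mgr = TinyLRUOffloadManager(num_slots=2)
print(mgr.prepare_store(["a", "b"]))  # []
mgr.prepare_load(["a"])               # "a" is pinned and now most recent
print(mgr.prepare_store(["c"]))       # ['b']: oldest unpinned block
print(mgr.prepare_store(["d"]))       # ['c']: "a" is still pinned
mgr.complete_load(["a"])
print(mgr.prepare_store(["e"]))       # ['a']
```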

The connector API: the seam this chapter introduces

The manager decides what to move. Something else has to actually move bytes between HBM and DRAM during the forward pass, and that something is the connector. KVConnectorBase_V1 is the abstraction this chapter introduces and the next chapter reuses wholesale. It is split cleanly into scheduler-side and worker-side halves, and its scheduler-side entry point is the hook that lets external KV participate in admission:

# materials/.../vllm/distributed/kv_transfer/kv_connector/v1/base.py
@abstractmethod
def get_num_new_matched_tokens(
    self, request: "Request", num_computed_tokens: int,
) -> tuple[int | None, bool]:
    """Get number of new tokens that can be loaded from the
    external KV cache beyond the num_computed_tokens."""

Source: vllm/distributed/kv_transfer/kv_connector/v1/base.py

The scheduler calls this for every waiting request, right after it has computed the GPU prefix-cache hit, and folds the answer into how much prefill it must actually schedule:

# materials/.../vllm/v1/core/sched/scheduler.py
# Get externally-cached tokens if using a KVConnector.
if self.connector is not None:
    ext_tokens, load_kv_async = (
        self.connector.get_num_new_matched_tokens(
            request, num_new_local_computed_tokens
        )
    )

Source: vllm/v1/core/sched/scheduler.py

This is the heart of the design. The GPU prefix cache answers “how many tokens are already in HBM”; the connector answers “how many more are recoverable from a cheaper tier.” Prefill only runs over what neither tier could supply. The return signature carries two further subtleties. A None first element means “ask me again later.” It lets a slow backend kick off an asynchronous lookup without blocking the scheduler step: the scheduler moves on and re-polls the request on a later step. The boolean second element distinguishes synchronous loads (the transfer finishes before the next step, so the request is ready immediately) from asynchronous ones (the request waits across several steps for its KV to arrive over a slow link). The asynchronous path is where this machinery starts to look like disaggregation, and it is the seam Chapter 17 widens.
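The admission arithmetic can be sketched as a pure function. The function name, the AdmissionDecision type, and its action labels are hypothetical. The sketch captures only the two-tier subtraction and the meaning of the (int or None, bool) reply. It leaves out details of vLLM's real scheduler, such as recomputing the final prompt position when every token is already cached.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class AdmissionDecision:
    action: str             # "retry", "park", or "prefill" (hypothetical labels)
    tokens_to_prefill: int  # prompt tokens neither tier could supply


def admit(prompt_len: int, local_hit: int,
          connector_reply: tuple[Optional[int], bool]) -> AdmissionDecision:
    ext_tokens, load_async = connector_reply
    if ext_tokens is None:
        # The backend is still looking: skip this request now, re-poll next step.
        return AdmissionDecision("retry", 0)
    remaining = prompt_len - local_hit - ext_tokens
    if load_async and ext_tokens > 0:
        # Blocks are reserved now; forward work waits until the KV has arrived.
        return AdmissionDecision("park", remaining)
    return AdmissionDecision("prefill", remaining)
```

For a 1000-token prompt with 256 tokens already in HBM, a synchronous reply of 512 leaves 232 tokens to prefill. An asynchronous reply of 744 parks the request with nothing left to compute once the KV lands.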

The worker-side half is built around the forward pass. Its methods bracket model execution:

# materials/.../vllm/distributed/kv_transfer/kv_connector/v1/base.py
@abstractmethod
def start_load_kv(self, forward_context, **kwargs) -> None:
    """Start loading the KV cache from the connector to vLLM's paged
    KV buffer... before the forward pass to enable async loading..."""
@abstractmethod
def wait_for_save(self):
    """Block until all the save operations is done..."""

Source: vllm/distributed/kv_transfer/kv_connector/v1/base.py

These methods wrap the forward pass, instead of running as separate blocking phases before and after it, so that transfers can overlap with compute. Data transfers and compute kernels can run at the same time on separate CUDA streams. The connector exploits this by starting loads just before the layers compute and pushing saves out as layers finish, hiding the transfer behind work the GPU was going to do anyway. start_load_kv fires before the layers run so a load can overlap compute; save_kv_layer and wait_for_save push freshly computed KV out without stalling the GPU. The model runner wraps all of this in a context manager so every execute_model enters and exits the connector lifecycle uniformly:

# materials/.../vllm/v1/worker/kv_connector_model_runner_mixin.py
kv_connector.bind_connector_metadata(scheduler_output.kv_connector_metadata)
# Background KV cache transfers happen here.
kv_connector.start_load_kv(get_forward_context())
try:
    yield output
finally:
    if wait_for_save and not defer_finalize:
        kv_connector.wait_for_save()
    output.finished_sending, output.finished_recving = (
        kv_connector.get_finished(scheduler_output.finished_req_ids)
    )

Source: vllm/v1/worker/kv_connector_model_runner_mixin.py

Note also kv_connector_no_forward, the mixin's path for steps with no tokens to compute. The engine still runs the connector on such steps, because a request blocked on an asynchronous remote load needs its transfer driven forward. The connector is not a passenger on the forward pass; it has its own work to do whether or not the model runs.
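The bracketing order, including the no-forward case, can be checked without a GPU. This CPU-only sketch uses a hypothetical RecordingConnector that only logs which hooks ran, and a contextlib context manager shaped like the mixin excerpt. The real hooks take a forward context and launch asynchronous copies; both are omitted here.

```python
from contextlib import contextmanager


class RecordingConnector:
    """Hypothetical connector whose hooks only record that they ran."""

    def __init__(self):
        self.calls: list[str] = []

    def bind_connector_metadata(self, metadata) -> None:
        self.calls.append("bind")

    def start_load_kv(self) -> None:
        self.calls.append("start_load_kv")

    def wait_for_save(self) -> None:
        self.calls.append("wait_for_save")

    def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
        self.calls.append("get_finished")
        return set(), set()  # (finished_sending, finished_recving)


@contextmanager
def connector_step(conn: RecordingConnector, metadata: dict, finished_req_ids: set[str]):
    conn.bind_connector_metadata(metadata)
    conn.start_load_kv()   # loads are launched before any layer runs
    output: dict = {}
    try:
        yield output       # the forward pass (if any) runs inside the with-block
    finally:
        conn.wait_for_save()
        output["finished_sending"], output["finished_recving"] = (
            conn.get_finished(finished_req_ids))


def execute_model(conn: RecordingConnector, num_tokens: int) -> dict:
    with connector_step(conn, metadata={}, finished_req_ids=set()) as output:
        if num_tokens > 0:
            conn.calls.append("forward")
        # With no tokens the connector hooks still run, so async transfers advance.
    return output
```

Because the teardown sits in a finally clause, wait_for_save and get_finished run even if the forward pass raises. That mirrors the structure of the mixin excerpt.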

The sequence diagram below traces one engine step from the scheduler’s admission check through the worker’s bracketed forward pass, showing how a load is launched before compute and a store is deferred until after it. The scheduler-side hook and the worker-side lifecycle are two halves of the same connector, talking through the metadata the model runner binds at the top of every step.

sequenceDiagram
    participant S as Scheduler
    participant C as Connector scheduler side
    participant W as Connector worker side
    participant G as GPU forward pass
    S->>C: get_num_new_matched_tokens(request)
    C-->>S: extra tokens recoverable from offload tier
    S->>S: schedule prefill only over the remainder
    S->>W: bind connector metadata for this step
    W->>W: start_load_kv begins DRAM-to-HBM copy
    W->>G: run layers, compute overlaps the load
    G-->>W: fresh KV produced
    W->>W: wait_for_save drains last step's deferred stores
    W->>W: queue this step's new stores for next step
    W-->>S: report finished sending/receiving ids

Moving the bytes: streams, copy engines, and partial blocks

When OffloadingConnector actually transfers a block it delegates to a worker that drives async copies on dedicated CUDA streams, and the comments in the copy handler are where the hardware reality shows through:

# materials/.../vllm/v1/kv_offload/cpu/gpu_worker.py
def _select_swap_blocks_fn(kv_cache_groups_data_refs, gpu_to_cpu):
    """Resolve the swap_blocks function for a handler at init time."""
    # GPU->CPU is bandwidth-bound; the dedicated copy engine beats Triton.
    if gpu_to_cpu:
        return ops.swap_blocks_batch

Source: vllm/v1/kv_offload/cpu/gpu_worker.py

GPU-to-CPU stores route to the dedicated DMA copy engine (a piece of hardware whose only job is moving bytes across the bus) because the transfer is purely bandwidth-bound and the copy engine runs it without occupying SMs, the streaming multiprocessors that do the model’s math. Spending compute units to shuffle bytes would steal them from the next forward pass for no benefit. Small CPU-to-GPU loads are the exception. There, a Triton kernel (a small GPU program) that gathers many tiny page copies into a single launch can win, but only when the payloads are small and 8-byte aligned. At that size per-transfer overhead dominates, so one batched kernel beats issuing many separate DMA descriptors. This is the same memory-vs-compute reasoning as everywhere else in the book, applied to the interconnect: match the transfer to the hardware unit that runs it most cheaply.
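The routing rule reduces to a small decision function. This is an illustrative sketch, not vLLM's logic. The 64 KiB small_threshold is a made-up cutoff, and copy_engine and gather_kernel are hypothetical labels for the two paths described above.

```python
def choose_copy_path(gpu_to_cpu: bool, bytes_per_copy: int,
                     small_threshold: int = 64 * 1024) -> str:
    """Pick the hardware path for a batch of KV page copies (illustrative only).

    small_threshold is a made-up cutoff, not a vLLM constant.
    """
    if gpu_to_cpu:
        return "copy_engine"    # bandwidth-bound store: leave the SMs to the model
    small = bytes_per_copy <= small_threshold
    aligned = bytes_per_copy % 8 == 0
    if small and aligned:
        return "gather_kernel"  # many tiny copies: one kernel launch amortizes overhead
    return "copy_engine"
```

The direction check comes first because stores never benefit from the kernel path. Only loads are worth testing for size and alignment.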

Two ordering details matter for correctness. Stores must wait for the model to finish writing the KV they are reading, and loads from pinned host memory can let the driver reorder source reads because nothing on the GPU is concurrently writing them:

# materials/.../vllm/v1/kv_offload/cpu/gpu_worker.py
if self.gpu_to_cpu:
    # wait for model computation to finish before offloading
    stream.wait_stream(current_platform.current_stream())

Source: vllm/v1/kv_offload/cpu/gpu_worker.py

And stores are deliberately deferred to the start of the next engine step, so that offloading bandwidth does not contend with the latency-critical transfers around token sampling:

# materials/.../vllm/distributed/kv_transfer/kv_connector/v1/offloading/worker.py
# NOTE(orozery): defer the store to the beginning of the next
# engine step, so that offloading starts AFTER transfers related
# to token sampling, thereby avoiding delays to token generation.
self._unsubmitted_store_jobs.append((job_id, entry.transfer_spec))

Source: vllm/distributed/kv_transfer/kv_connector/v1/offloading/worker.py

This is the kind of detail that separates a working offload path from a fast one: a naive implementation that flushes stores synchronously inside the step adds DRAM-write latency to every decode, exactly where the ITL budget is thinnest.
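The deferral can be modeled as a queue that is drained at the top of the next step. The class and its string timeline are hypothetical. The timeline records submission order only, since on real hardware the submitted stores then run on their own stream alongside the next forward pass.

```python
class DeferredStoreQueue:
    """Toy model: stores queued in step N are submitted at the start of step N+1."""

    def __init__(self):
        self._unsubmitted: list[str] = []
        self.timeline: list[str] = []  # submission order, not completion order

    def step(self, new_store_jobs: list[str]) -> None:
        # 1. Submit whatever the previous step deferred; it then runs on its own
        #    stream alongside this step's forward pass.
        self.timeline.extend(f"submit_store:{job}" for job in self._unsubmitted)
        self._unsubmitted = []
        # 2. The latency-critical work of this step.
        self.timeline += ["forward", "sample"]
        # 3. This step's stores wait, so they cannot delay this step's sampling.
        self._unsubmitted.extend(new_store_jobs)
```

After two steps the timeline reads forward, sample, submit_store:j1, forward, sample. Job j1 never sits between a forward pass and its sampling.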

There is also a granularity mismatch the worker has to absorb. Offloaded blocks are often coarser than GPU blocks: one offloaded block can hold several GPU blocks, a ratio $r = \text{block\_size}_{\text{offload}} / \text{block\_size}_{\text{GPU}}$ set by block_size in the connector config, so the slow tier stores fewer, larger pages. That keeps its index small and its transfers efficient, but it means a load no longer lines up cleanly with the GPU’s finer blocks. Suppose the portion a request needs to load does not start exactly on one of those coarse boundaries: the first half of a coarse block is already resident on the GPU and only the second half must be loaded. Then the leading offloaded block is only partially relevant. The worker therefore carries block_indices to know how many sub-blocks to skip at the front of that first block, so it copies only the bytes the request actually needs. The GPULoadStoreSpec docstring spells out precisely why this bookkeeping exists, and it is the price of making the offload tier’s page size independent of the GPU’s 16-token blocks.
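The ratio and the skip count are simple integer arithmetic. This minimal sketch assumes the offload block size is an exact multiple of the GPU block size. It also assumes that load_start_token falls on a GPU-block boundary. load_start_token is a hypothetical parameter meaning the first prompt token not already resident on the GPU.

```python
def offload_layout(block_size_offload: int, block_size_gpu: int,
                   load_start_token: int) -> tuple[int, int, int]:
    """Return (r, first offloaded block to read, GPU-sized sub-blocks to skip in it)."""
    if block_size_offload % block_size_gpu != 0:
        raise ValueError("offload block size must be a multiple of the GPU block size")
    if load_start_token % block_size_gpu != 0:
        raise ValueError("hits are tracked at GPU-block granularity")
    r = block_size_offload // block_size_gpu
    first_offload_block = load_start_token // block_size_offload
    skip = (load_start_token % block_size_offload) // block_size_gpu
    return r, first_offload_block, skip
```

Take 64-token offloaded blocks over 16-token GPU blocks, so r = 4. A load that starts at token 96 begins in offloaded block 1 and must skip its first 2 GPU-sized sub-blocks. A load starting at token 128 lands on a boundary and skips nothing.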

Eviction, races, and the unglamorous correctness work

The genuinely hard part of offloading is not the happy path; it is keeping the two allocators consistent while requests preempt, abort, and reuse blocks underneath in-flight transfers. The danger is a use-after-free in slow motion. A store is still reading a GPU block to copy it down to DRAM when the scheduler, seeing that block as free, hands it to a new request that overwrites it. The store then captures garbage. The fix is a fence: code that forces the in-flight transfer to finish (a flush) before the block can be reused. The scheduler-side connector is full of these. A store that is still draining to DRAM holds claims on GPU blocks the scheduler might want to reallocate, so the scheduler tracks which pending jobs touch which block ID and forces a flush before any tracked block is handed to a new request:

# materials/.../vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py
# Flush jobs that contain re-allocated blocks.
if (
    self._block_id_to_pending_jobs
    and not self._block_id_to_pending_jobs.keys().isdisjoint(
        self._current_batch_allocated_block_ids
    )
):
    self._current_batch_jobs_to_flush.update(...)

Source: vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py

Preempted requests get their pending stores flushed too, because preemption (the load-shedding valve from Chapter 6) frees blocks that an outstanding store may still be reading. Sliding-window models add another wrinkle: their blocks can be dropped before the request even finishes, so those blocks are watched from the moment the store is created rather than at request completion. None of this is exotic distributed-systems theory; it is the careful reference-counting that any cache with asynchronous write-back and a shared backing store has to get exactly right, and it is worth reading the scheduler in full to appreciate how much of the file is fences rather than transfers.
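The fence check amounts to a reverse index from block ID to pending jobs, plus a set-disjointness test. The sketch assumes string job IDs and integer block IDs. StoreFenceTracker and its methods are hypothetical; the isdisjoint fast path mirrors the scheduler excerpt above.

```python
from collections import defaultdict


class StoreFenceTracker:
    """Track which in-flight store jobs still read which GPU block ids (sketch)."""

    def __init__(self):
        self.block_id_to_pending_jobs: dict[int, set[str]] = defaultdict(set)

    def start_store(self, job_id: str, block_ids: list[int]) -> None:
        for b in block_ids:
            self.block_id_to_pending_jobs[b].add(job_id)

    def finish_store(self, job_id: str) -> None:
        for b in list(self.block_id_to_pending_jobs):
            self.block_id_to_pending_jobs[b].discard(job_id)
            if not self.block_id_to_pending_jobs[b]:
                del self.block_id_to_pending_jobs[b]

    def jobs_to_flush(self, newly_allocated: set[int]) -> set[str]:
        """Jobs that must complete before these blocks can be handed out."""
        if self.block_id_to_pending_jobs.keys().isdisjoint(newly_allocated):
            return set()  # common case: no overlap, no fence
        flush: set[str] = set()
        for b in newly_allocated:
            flush |= self.block_id_to_pending_jobs.get(b, set())
        return flush
```

The disjointness test keeps the common case cheap: most steps allocate no block that a store is still reading, and they pay only one set comparison.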

Tiers below CPU, and pluggable external stores

CPU DRAM is the first tier, not the last. The TieringOffloadingSpec makes CPU a primary tier with configurable secondaries:

# materials/.../vllm/v1/kv_offload/tiering/spec.py
"secondary_tiers": [
    {"type": "example", "custom_param": 67}
]

Source: vllm/v1/kv_offload/tiering/spec.py

A FileMapper hashes blocks to filenames for an NVMe tier (the block hash becomes the file path, so the same content always maps to the same file); object-storage and networked tiers slot in behind the same OffloadingManager interface. The contract never changes: lookup, prepare_load, prepare_store. The tier just answers more slowly and holds more. This is the payoff of having defined the offload tier as an allocator with those three verbs: a new tier only has to implement the verbs, and the scheduler and worker above it are none the wiser.
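The content-to-path mapping only needs to be deterministic. Below is a minimal sketch of one such mapping. block_path, the .kv suffix, and the two-character directory fan-out are hypothetical choices of this sketch, not the layout vLLM's FileMapper uses.

```python
from pathlib import Path


def block_path(root: Path, offload_key: bytes, fanout_chars: int = 2) -> Path:
    """Map a content key to a stable file path (hypothetical layout).

    The leading hex characters become a subdirectory so that no single
    directory accumulates every block file.
    """
    name = offload_key.hex()
    return root / name[:fanout_chars] / f"{name}.kv"
```

Any process that computes the same key finds the same file. That is all a later lookup needs to locate a block written by an earlier request.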

Above the in-tree offloading machinery sits a second integration path: external KV stores that bring their own caching brain. The connector registry makes these first-class. LMCacheConnectorV1 wraps the LMCache engine, delegating every connector method straight through, and even advertises when it needs piecewise CUDA graphs because its layerwise loads cannot be captured (a callback to Chapter 10’s capture-vs-eager split):

# materials/.../vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py
@classmethod
def requires_piecewise_for_cudagraph(cls, extra_config) -> bool:
    return extra_config.get("use_layerwise", False)

Source: vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py

The Mooncake store is another such backend living in the same connector tree (mooncake/store/). What these external stores buy you is a KV cache that is shared across replicas rather than private to one process, which is precisely the bridge from this chapter’s single-replica offloading to the fleet-level cache-aware routing of Chapter 18, and to disaggregation, where producer and consumer are different machines entirely.

What is still unsolved

Offloading does not make the recompute-versus-load decision for you; it gives you the mechanism and leaves the policy underspecified. The store threshold and prompt-only defaults are blunt heuristics, and the genuinely right answer depends on prefix length distribution, reuse frequency, and live PCIe contention, none of which the engine measures end to end. Compression is the other open frontier: KV blocks can in principle be quantized or entropy-coded before they hit the slow tier, trading a little decode-time fidelity for a lot more effective cache and cheaper transfers, but in-tree vLLM offloads raw bytes and the integration surface for compressed KV is still maturing. And every tier you add lengthens the tail: a prefix that lives on NVMe behind a busy queue can be slower to fetch than to recompute, which means the lookup needs to be load-aware, not just presence-aware. The API has a None-means-retry escape hatch for slow backends, but choosing not to wait is still a manual call.

The connector you just met is intentionally more general than offloading needs. Its asynchronous-load path, its scheduler hook that admits external tokens, its worker-side load/save lifecycle: all of it works just as well when the “external store” is another GPU’s HBM reached over a network instead of this host’s DRAM reached over PCIe. That is the whole idea of the next chapter. Prefill and decode have opposite hardware appetites, so we will run them on separate GPU pools and ship the KV cache between them, over this same connector API, and the only thing that really changes is that the wire gets longer.

Further reading

  • CacheGen: KV Cache Compression and Streaming for Fast Large Language Model Serving — arXiv:2310.07240 — Compresses KV state into a compact bitstream and streams it to the serving node, attacking the load-versus-recompute tradeoff this chapter leaves open; read it for the compression frontier of offloading.
  • Mooncake: A KVCache-centric Disaggregated Architecture for LLM Serving — arXiv:2407.00079 — Treats a tiered, replica-shared KV store as the organizing center of the serving system; read it for where the connector abstraction is heading and as the design behind one of vLLM’s in-tree store backends.
  • Efficient Memory Management for Large Language Model Serving with PagedAttention — arXiv:2309.06180 — The paging idea this chapter extends one tier down; reread it to see the GPU block pool and the offload tier as the same allocator at two levels.

Prefill/decode disaggregation

Every chapter in Part II fought the same quiet war. Chapter 3 named the enemy: a prefill step is compute-bound, saturating the GPU’s tensor cores over the whole prompt at once, while a decode step is memory-bound, dragging the entire model and KV cache through HBM to emit a single token. Chapter 6 then put both kinds of work into one batch and one token budget, and chunked prefill exists precisely because mixing them is awkward. A long prefill, left whole, stalls every decode sharing its step; sliced into chunks, it merely taxes them. Chunked prefill is a truce, not a victory. The two phases still ride the same GPU, still draw from the same token budget, still interfere.

There is a more radical option, and this chapter is about it: stop sharing the GPU at all. Run prefill on one pool of GPUs and decode on another. Let the prefill pool run flat-out at its compute roofline with no decodes to slow it down, let the decode pool run flat-out at its bandwidth roofline with no prefill spikes wrecking its inter-token latency, and ship the KV cache from the first pool to the second over the network. The two pools scale independently, because their bottlenecks are independent. And the wire that carries the KV between them is not a new abstraction at all. It is the connector API from Chapter 16, the same KVConnectorBase_V1 that spilled cold blocks to CPU and disk, now stretched across a NIC.

That last point is the through-line. Offloading and disaggregation are the same machinery pointed at different destinations.

A piece of vocabulary first, because the rest of the chapter leans on it. The prefill instance (the “P” side, or producer) is the GPU pool that reads the whole prompt and computes its KV cache once. The decode instance (the “D” side, or consumer) is the pool that takes that KV cache and generates output tokens one at a time. The KV cache is the per-token key/value tensors that attention needs in order to attend back over everything seen so far; computing it for an N-token prompt is exactly the expensive prefill work, and once computed it is just bytes that can be copied. Disaggregation’s entire bet is that copying those bytes from P to D is cheaper than recomputing them on D. The diagram below shows the shape of the deployment: a request flows through P, its KV cache crosses the network, and D streams tokens back to the user.

flowchart LR
    U["client request"] --> P["prefill pool (producer): compute full KV cache once, compute-bound"]
    P -->|"KV cache over RDMA"| D["decode pool (consumer): emit one token per step, memory-bound"]
    D -->|"output tokens"| U
    P -.->|"scales on prompt-token rate"| AS["autoscaler"]
    D -.->|"scales on output-token rate"| AS

The dashed edges hint at the payoff the next section makes precise: because P is throttled by compute and D by memory bandwidth, the two pools can be sized and scaled by completely different signals.

Why separate at all: the goodput argument

Chunked prefill tames interference but cannot erase it, and it forces a single configuration onto two workloads with opposite needs. Prefill wants large batches and big tensor-parallel groups to feed the compute units; decode wants the smallest TP degree that still fits the model, so that each token’s mandatory walk through HBM is as short as possible. One GPU pool cannot be tuned for both. You pick a compromise and both phases pay for it.

Two 2024 papers made this argument quantitative and gave the field its vocabulary. DistServe (arXiv:2401.09670) disaggregates prefill and decode onto separate GPU pools and optimizes for goodput rather than raw throughput. The distinction is worth pinning down, because it is the whole reason to bother. Throughput counts every token a GPU emits, fast or slow. Goodput counts only the requests that actually met both of their latency targets from Chapter 2: TTFT (time to first token, set by prefill) and TPOT (time per output token, set by decode). A GPU can have wonderful throughput while missing SLOs left and right, if its batches are large but laggy.

The insight is that prefill and decode degrade different SLOs when they interfere: a long prefill sharing a step inflates everyone’s TPOT, and a flood of decodes sharing a step inflates everyone’s TTFT. Co-locating them forces a single batching and parallelism configuration onto two phases with opposite optimal points, so you are always sacrificing one SLO to protect the other. DistServe’s finding is that splitting the phases, even after paying to move the KV cache, raises the number of requests served within both SLOs per GPU. Read it for the goodput formulation and the placement search over parallelism degrees.

Splitwise (arXiv:2311.18677) reaches the same structural conclusion from a production-fleet angle: it splits the two phases onto separate machine pools, observes that prefill and decode have different power and hardware sweet spots (you can even use different GPU SKUs per pool), and works out KV-transfer and provisioning so the split pays off across a real cluster. Read it for the heterogeneous-hardware and capacity-planning view.
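The difference between the two metrics is easiest to see on one set of requests. This is a simplified sketch that assumes per-request TTFT and mean TPOT have already been measured over a fixed window. DistServe itself states goodput as the request rate a system can sustain while meeting a target SLO attainment, which this per-window count only approximates. RequestTiming and throughput_and_goodput are hypothetical names.

```python
from dataclasses import dataclass


@dataclass
class RequestTiming:
    ttft_s: float       # time to first token
    tpot_s: float       # mean time per output token after the first
    output_tokens: int


def throughput_and_goodput(reqs: list[RequestTiming], ttft_slo_s: float,
                           tpot_slo_s: float, window_s: float) -> tuple[float, float]:
    """Tokens/s counting everything vs. requests/s counting only SLO-meeting ones."""
    tokens = sum(r.output_tokens for r in reqs)
    good = sum(1 for r in reqs if r.ttft_s <= ttft_slo_s and r.tpot_s <= tpot_slo_s)
    return tokens / window_s, good / window_s


reqs = [
    RequestTiming(0.2, 0.03, 100),
    RequestTiming(1.5, 0.03, 100),  # misses TTFT
    RequestTiming(0.3, 0.09, 200),  # misses TPOT
    RequestTiming(0.4, 0.04, 100),
]
tput, gput = throughput_and_goodput(reqs, ttft_slo_s=0.5, tpot_slo_s=0.05, window_s=10.0)
```

Four requests emitting 500 tokens in 10 s give 50 tokens/s of throughput. Because one request misses TTFT and another misses TPOT, goodput is only 0.2 requests/s.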

Both papers leave you with the same caveat, and it is the one to hold onto: disaggregation only wins when the KV transfer is cheap relative to the prefill it replaces. Move the bytes too slowly, or move them for a prompt too short to be worth it, and you have added a network hop for nothing. Everything in vLLM’s implementation below is, in one way or another, an attempt to keep that transfer off the critical path.
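Whether the transfer is cheap relative to the prefill is back-of-the-envelope arithmetic. The sketch assumes a hypothetical model (32 layers, 8 KV heads, head dimension 128, 2-byte bf16 KV) and a hypothetical 50 GB/s effective link. For prefill it uses the common approximation of about 2 FLOPs per parameter per token, which ignores attention's quadratic term and all fixed per-transfer overheads. Replace every number with measurements from your own hardware.

```python
def kv_bytes_per_token(num_layers: int, num_kv_heads: int, head_dim: int,
                       dtype_bytes: int) -> int:
    # One key vector and one value vector per layer per KV head.
    return 2 * num_layers * num_kv_heads * head_dim * dtype_bytes


def transfer_seconds(prompt_tokens: int, bytes_per_token: int,
                     link_bytes_per_s: float) -> float:
    # Pure bandwidth term; connection setup and per-transfer latency are ignored.
    return prompt_tokens * bytes_per_token / link_bytes_per_s


def prefill_seconds_estimate(prompt_tokens: int, num_params: float,
                             effective_flops: float) -> float:
    # ~2 FLOPs per parameter per token; ignores the attention term.
    return 2 * num_params * prompt_tokens / effective_flops


# Hypothetical model and hardware numbers, for illustration only.
bpt = kv_bytes_per_token(num_layers=32, num_kv_heads=8, head_dim=128, dtype_bytes=2)
move_s = transfer_seconds(4096, bpt, link_bytes_per_s=50e9)
prefill_s = prefill_seconds_estimate(4096, num_params=8e9, effective_flops=400e12)
```

Under those assumptions a 4096-token prompt carries 512 MiB of KV, which takes about 11 ms to move. The prefill it replaces costs roughly 160 ms for an 8B-parameter model at 400 TFLOP/s effective. Both terms scale linearly with prompt length in this model. The gap between them is therefore fixed by bytes per token over link bandwidth versus FLOPs per token over compute. For short prompts, the fixed cost of the extra hop is what tips the balance.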

The shape of a disaggregated deployment

vLLM does not have a “disaggregation mode.” It has a connector and a role flag, and disaggregation is what you get when you point a producer instance and a consumer instance at the same connector. The role lives in KVTransferConfig:

# vllm/config/kv_transfer.py
kv_role: KVRole | None = None
"""Whether this vLLM instance produces, consumes KV cache, or both. Choices
are 'kv_producer', 'kv_consumer', and 'kv_both'."""

Source: vllm/config/kv_transfer.py

A prefill instance runs as a producer, a decode instance as a consumer. Each is an ordinary vLLM engine with an ordinary scheduler and paged KV cache; the only difference is that a connector is wired in. That connector is created once, lazily, and stashed as a process-global on the worker side:

# vllm/distributed/kv_transfer/kv_transfer_state.py
if (
    vllm_config.kv_transfer_config.is_kv_transfer_instance
    and _KV_CONNECTOR_AGENT is None
):
    _sync_engine_id_across_tp(vllm_config)

    _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector(
        config=vllm_config,
        role=KVConnectorRole.WORKER,
        kv_cache_config=kv_cache_config,
    )

Source: vllm/distributed/kv_transfer/kv_transfer_state.py

Note _sync_engine_id_across_tp: every worker in a TP group must agree on the engine id, because that id is how a remote instance addresses this one across the network. Disaggregation is point-to-point between specific engines, not a broadcast.

The connector you choose decides how the bytes actually move, and there are many, not one. The factory registers a whole zoo: NixlConnector, LMCacheConnectorV1, MooncakeConnector, the offloading connectors from Chapter 16, and MultiConnector, which composes several at once. NIXL (NVIDIA’s transfer library, point-to-point RDMA) is the reference disaggregation backend and the one we will read, but the API is deliberately backend-agnostic. Mooncake (arXiv:2407.00079, met in Chapter 16) is both a paper and a backend in this tree: a KV-centric tiered store that a decode pool can pull from. The connector abstraction is what lets all of these be alternatives rather than forks.
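With the role flag and the connector choice combined, each instance's KV-transfer configuration is a small JSON object. The sketch below builds and validates one per pool. kv_connector and kv_role are KVTransferConfig fields, and the role choices come from the docstring quoted earlier; the helper itself is hypothetical. The resulting string would be handed to vLLM, for example through its --kv-transfer-config option. Which role each side should use is connector-specific, so check the connector's own documentation.

```python
import json

VALID_KV_ROLES = {"kv_producer", "kv_consumer", "kv_both"}  # from the docstring above


def kv_transfer_config_json(connector: str, role: str) -> str:
    """Serialize one instance's KV-transfer settings (hypothetical helper)."""
    if role not in VALID_KV_ROLES:
        raise ValueError(f"kv_role must be one of {sorted(VALID_KV_ROLES)}, got {role!r}")
    return json.dumps({"kv_connector": connector, "kv_role": role})


prefill_cfg = kv_transfer_config_json("NixlConnector", "kv_producer")
decode_cfg = kv_transfer_config_json("NixlConnector", "kv_consumer")
```

Validating the role before launch turns a typo into an immediate error rather than a pair of instances that silently never exchange KV.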

The connector lifecycle, split across two sides

The base class documents its own contract better than I can paraphrase, so here is the seam. Every connector has a scheduler-side half and a worker-side half:

# vllm/distributed/kv_transfer/kv_connector/v1/base.py
    Scheduler-side: runs in the scheduler, binds metadata, which
    is used by the worker-side to load/save KV cache.
        get_num_new_matched_tokens() - get number of new tokens
            that exist in the remote KV cache. ...
    Worker-side: runs in each worker, loads/saves KV cache to/from
    the Connector based on the metadata.
        start_load_kv() - starts loading all KVs (maybe async)
        ...
        get_finished() - called with ids of finished requests, returns
            ids of requests that have completed async sending/recving.

Source: vllm/distributed/kv_transfer/kv_connector/v1/base.py

The split is the key to understanding everything that follows, so it is worth saying plainly what each half is and why there are two. The scheduler-side half lives in the engine’s step loop on the CPU. It never touches a GPU. Its job is purely to decide and to bookkeep: it answers “are these tokens available remotely?”, “have they arrived yet?”, and “may I free these blocks?”, and it emits small metadata structs describing what should move. The worker-side half lives inside execute_model, next to the GPU. Its job is to do the bytes: issue the RDMA reads, poll for completions, and hand a list of finished requests back up. Crucially, the two halves run on different clocks. The scheduler issues a transfer in one step but the bytes may not land for several steps, so the scheduler cannot simply block and wait. It has to park the request, keep stepping other work, and learn about completion asynchronously. That gap between “transfer requested” and “transfer done” is the central design problem of this whole chapter, and it is why a request needs a dedicated waiting state, which is exactly where the next section goes. The diagram below traces the two halves and the asynchronous handoff between them.

flowchart TD
    subgraph SCHED["scheduler-side (CPU, step loop)"]
        A["get_num_new_matched_tokens(): how many tokens are remote?"]
        B["park request in WAITING_FOR_REMOTE_KVS, reserve blocks"]
        C["each step: is request in finished_recving set yet?"]
        D["yes: rejoin schedulable pool, run forward"]
    end
    subgraph WORK["worker-side (GPU, execute_model)"]
        E["start_load_kv(): post RDMA read"]
        F["forward pass runs while bytes stream in"]
        G["get_finished(): poll completions"]
    end
    A --> B
    B -.->|"metadata describes the transfer"| E
    E --> F --> G
    G -.->|"finished_recving ids flow back up"| C
    C -->|"not yet"| C
    C --> D

The decode side: get_num_new_matched_tokens and the WAITING_FOR_REMOTE_KVS state

Start at the consumer. A request arrives at a decode instance with no local KV cache, because the prompt was prefilled elsewhere. The decode scheduler must learn that the tokens it would normally need to compute are already sitting in a remote prefill instance, allocate space for them, and not schedule any forward work until they have arrived. The hook for the first step is get_num_new_matched_tokens, called in the scheduler’s WAITING-queue pass right after the local prefix-cache lookup from Chapter 7:

# vllm/v1/core/sched/scheduler.py
if self.connector is not None:
    ext_tokens, load_kv_async = (
        self.connector.get_num_new_matched_tokens(
            request, num_new_local_computed_tokens
        )
    )

Source: vllm/v1/core/sched/scheduler.py

The connector returns two things: how many tokens it can supply beyond what the local prefix cache already has, and whether those tokens load asynchronously. For the NIXL consumer, a remote-prefill request reports its whole prompt as remotely available and flags the load as async:

# vllm/distributed/kv_transfer/kv_connector/v1/nixl/scheduler.py
if params is not None and params.get("do_remote_prefill"):
    # Remote prefill: get all prompt blocks from remote.
    token_ids = request.prompt_token_ids or []
    actual = self._mamba_prefill_token_count(len(token_ids))
    count = actual - num_computed_tokens
    if count > 0:
        return count, True

Source: vllm/distributed/kv_transfer/kv_connector/v1/nixl/scheduler.py

That True is what reshapes the whole step. The scheduler allocates blocks for the external tokens, then parks the request instead of running it:

# vllm/v1/core/sched/scheduler.py
if load_kv_async:
    # If loading async, allocate memory and put request
    # into the WAITING_FOR_REMOTE_KV state.
    request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
    ...
    request.num_computed_tokens = num_computed_tokens
    self._inflight_prefills.add(request)
    continue

Source: vllm/v1/core/sched/scheduler.py

WAITING_FOR_REMOTE_KVS is a sibling of WAITING_FOR_STRUCTURED_OUTPUT_GRAMMAR from Chapter 14: a request that is admitted and has memory reserved, but is blocked on an external event before it can be scheduled. What makes it cheap is that it consumes KV blocks but no token budget. Recall from Chapter 5 that the token budget is the scheduler’s per-step allowance of tokens it will actually push through the model. A parked request reserves the memory its KV will land in, but because no forward work runs for it, it does not eat into the budget that the running decodes are competing over. So a decode instance can hold a large backlog of inbound prefills without starving the tokens it is currently generating.

Notice too that num_computed_tokens is set to the full external count before the bytes arrive. This is deliberate optimistic bookkeeping: the scheduler pretends the tokens are already computed so that, once they land, the request can step straight into decode with nothing left to do. The scheduler source flags the optimism in a comment, because if the transfer fails those tokens are reset to what was actually loaded, and the request falls back to recomputing the rest locally.

The request sits in that state until the worker-side connector reports the receive complete. The scheduler learns this through KVConnectorOutput.finished_recving, and only then does the request rejoin the schedulable pool:

# vllm/v1/core/sched/scheduler.py
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
    if request.request_id not in self.finished_recving_kv_req_ids:
        return False
    self._update_waiting_for_remote_kv(request)

Source: vllm/v1/core/sched/scheduler.py

_update_waiting_for_remote_kv does one more thing that ties back to Chapter 7. The transferred blocks are now real cached KV, so it registers them with the local block pool: a remote prefill becomes a local prefix-cache entry, free to be reused by the next request that shares the prompt. It also handles a subtle edge case worth unpacking. Prefill computes two distinct things for the prompt: the KV cache (one key/value pair per token, which is what got transferred) and the logits (the probability distribution over the next token, which prefill produces only for the final position and which decode needs in order to sample the first output token). The transfer carries the KV but not the logits, so after a full-prompt transfer the decode instance has KV for every token yet no logits for the last one. The fix is to mark just the final position as not-yet-computed, forcing the model to run a single-token forward pass over it. That one pass re-derives the missing logits from the KV that is already present.

# vllm/v1/core/sched/scheduler.py
# on a full prompt hit, we need to re-compute the last token
# in order to be able to sample the next token
if request.num_computed_tokens == request.num_tokens:
    request.num_computed_tokens = request.num_tokens - 1

Source: vllm/v1/core/sched/scheduler.py

One position of recompute, never the whole prompt. That is the entire saving disaggregation buys, and it is exactly the saving prefix caching buys; the only difference is the cache lives across a network.
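The decode-side lifecycle covered so far can be condensed into a toy state machine. Everything in it is hypothetical: Req, DecodeScheduler, admit_remote, on_step, and the loaded_tokens map standing in for what the worker reports. It models only four behaviors:

- parking, which reserves blocks but spends no token budget;
- the optimistic num_computed_tokens;
- the rollback after a short transfer;
- the last-position recompute.

Prefix-cache registration is omitted.

```python
from dataclasses import dataclass, field


@dataclass
class Req:
    req_id: str
    num_tokens: int               # prompt length
    num_computed_tokens: int = 0
    status: str = "WAITING"


@dataclass
class DecodeScheduler:
    """Toy model of parking a request until its remote KV lands (hypothetical)."""
    free_blocks: int
    block_size: int
    parked: dict[str, Req] = field(default_factory=dict)

    def admit_remote(self, req: Req) -> None:
        needed = -(-req.num_tokens // self.block_size)  # ceiling division
        if needed > self.free_blocks:
            raise MemoryError("not enough KV blocks to park this request")
        self.free_blocks -= needed                  # memory is reserved now,
        req.num_computed_tokens = req.num_tokens    # bookkeeping is optimistic,
        req.status = "WAITING_FOR_REMOTE_KVS"
        self.parked[req.req_id] = req               # but no token budget is spent

    def on_step(self, finished_recving: set[str],
                loaded_tokens: dict[str, int]) -> list[Req]:
        """Unpark requests whose receive finished; return the ones now runnable."""
        ready = []
        for req_id in list(self.parked):
            if req_id not in finished_recving:
                continue  # still in flight: stay parked
            req = self.parked.pop(req_id)
            # A short or failed transfer rolls the optimistic count back.
            req.num_computed_tokens = min(req.num_computed_tokens,
                                          loaded_tokens.get(req_id, 0))
            if req.num_computed_tokens == req.num_tokens:
                # Full hit: KV is present but the last position's logits are not.
                req.num_computed_tokens = req.num_tokens - 1
            req.status = "RUNNING"
            ready.append(req)
        return ready
```

Consider a 1000-token request with 16-token blocks. It reserves 63 blocks when parked and stays parked until its ID appears in finished_recving. It then resumes with 999 tokens computed, so exactly one position is recomputed.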

We have now seen both ends in isolation. The sequence diagram below stitches them into the single timeline of one request, from arrival at the decode instance through to the first sampled token, and shows where the parking and the pull-based transfer fall.

sequenceDiagram
    participant U as Client
    participant DS as Decode scheduler (D, CPU)
    participant DW as Decode worker (D, GPU)
    participant PW as Prefill worker (P, GPU)
    U->>DS: request, prompt already prefilled on P
    DS->>DS: get_num_new_matched_tokens reports N tokens remote, async
    DS->>DS: reserve blocks, state = WAITING_FOR_REMOTE_KVS
    Note over DS: consumes KV blocks, no token budget
    DW->>PW: start_load_kv posts RDMA READ of KV blocks (pull)
    PW-->>DW: KV bytes stream in over the wire
    DW->>DS: get_finished puts request id in finished_recving
    DS->>DS: register blocks in local prefix cache
    DS->>DS: mark last position uncomputed (need logits)
    DS->>DW: schedule single-token forward over last position
    DW->>U: first output token
    Note over DW,U: then normal decode, one token per step

The prefill side: holding blocks until they are read

Now the producer. A prefill instance does its normal compute-bound forward pass, but it must not free the KV blocks the instant the request “finishes,” because a decode instance still needs to read them. The scheduler-side request_finished is where it decides to defer:

# vllm/distributed/kv_transfer/kv_connector/v1/nixl/scheduler.py
delay_free_blocks = any(len(group) > 0 for group in block_ids)
...
if delay_free_blocks:
    # Prefill request on remote. It will be read from D upon completion
    request_kv_blocks_ttl = self._kv_lease_duration
    ...
    self._reqs_need_send[request.request_id] = (
        time.perf_counter() + request_kv_blocks_ttl
    )

Source: vllm/distributed/kv_transfer/kv_connector/v1/nixl/scheduler.py

Returning delay_free_blocks=True tells the engine the connector has taken over responsibility for these blocks; they will not be freed when the request finishes. This is the same return contract request_finished uses for async offload saves in Chapter 16. The difference is the lease. A lease here means a time-bounded pin: rather than holding the blocks until some consumer definitely reads them (which might be never, if the consumer crashes), the producer pins them for only _kv_lease_duration seconds. If no decode instance shows up to read them before the lease expires, they are reclaimed anyway. For transfers that are legitimately still in flight, heartbeats (_heartbeat_by_engine) renew the lease so a slow-but-live reader is not cut off. The motivation is a failure mode: without the timeout, a crashed or slow consumer would strand prefill memory forever, and because every parked block is capacity the prefill pool cannot reuse, the pool would slowly bleed throughput until it could admit no new prompts. Distributed cache coherence under failure is, unsurprisingly, where most of the operational complexity hides.
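The lease idea can be sketched as a small deadline table. KVLeaseTable and its methods are hypothetical names, not vLLM's code; leases are keyed per request for simplicity, whereas the source's _heartbeat_by_engine suggests renewal is tracked per engine, which this sketch does not model. Passing an explicit `now` keeps the example deterministic.

```python
import time
from typing import Dict, List, Optional


class KVLeaseTable:
    """Toy producer-side lease table (hypothetical; not vLLM's implementation).

    Each request whose blocks are parked for a remote reader gets a deadline.
    A heartbeat from a live reader pushes the deadline out; expired entries are
    reclaimed so a crashed consumer cannot strand prefill memory forever.
    """

    def __init__(self, lease_seconds: float) -> None:
        self.lease_seconds = lease_seconds
        self._deadline: Dict[str, float] = {}

    def pin(self, req_id: str, now: Optional[float] = None) -> None:
        now = time.perf_counter() if now is None else now
        self._deadline[req_id] = now + self.lease_seconds

    def heartbeat(self, req_id: str, now: Optional[float] = None) -> bool:
        """Renew a lease; False if the entry is no longer tracked."""
        if req_id not in self._deadline:
            return False
        now = time.perf_counter() if now is None else now
        self._deadline[req_id] = now + self.lease_seconds
        return True

    def release(self, req_id: str) -> None:
        # The reader finished (or had a full local hit) and said to free early.
        self._deadline.pop(req_id, None)

    def reclaim_expired(self, now: Optional[float] = None) -> List[str]:
        now = time.perf_counter() if now is None else now
        expired = [r for r, d in self._deadline.items() if d <= now]
        for r in expired:
            del self._deadline[r]
        return expired


leases = KVLeaseTable(lease_seconds=30.0)
leases.pin("req-a", now=0.0)
leases.pin("req-b", now=0.0)
leases.heartbeat("req-a", now=25.0)       # slow but live reader renews
print(leases.reclaim_expired(now=31.0))   # ['req-b']: nobody came for it
print(leases.reclaim_expired(now=56.0))   # ['req-a']: renewed, then went silent
```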

The transfer itself is a pull, not a push, and the direction matters. A push would have the producer send bytes the instant prefill finishes, but the producer does not know where the blocks should land on the consumer until the consumer has allocated its own block slots, and it does not know whether the consumer even still needs them (it may already hold them in its prefix cache). A pull inverts this: the consumer, which knows its own layout and its own needs, reaches across RDMA and reads the blocks it wants from the producer’s memory. That is why all the initiating logic we saw lives on the decode side. The worker-side read for NIXL posts a point-to-point transfer per remote worker:

# vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py
def _read_blocks(
    self,
    read_spec: ReadSpec,
    dst_engine_id: str,
    ...
):
    """
    Post a READ point-to-point xfer request from a single local worker to
    a single remote worker.
    """

Source: vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py

The full-prefix-cache-hit case is the elegant degenerate one: if the decode instance already has all the blocks locally, it reads nothing and merely sends a notification so the prefill side knows it can release its lease early.

# vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py
# Full prefix cache hit: do not need to read remote blocks,
# just notify P worker that we have the blocks we need.
if len(local_block_ids) == 0:

Source: vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py
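The decode-side decision can be sketched as a pure function. plan_remote_read is a hypothetical helper, and the tail alignment is an assumption: when the decode side already caches a leading run of blocks, it allocates local slots only for the rest, so only the last remote blocks need reading.

```python
from typing import List, Sequence, Tuple


def plan_remote_read(
    remote_block_ids: Sequence[int],
    local_block_ids: Sequence[int],
) -> Tuple[str, List[Tuple[int, int]]]:
    """Decide what the decode side does for one request (hypothetical helper).

    remote_block_ids: block ids holding the prompt's KV on the prefill side.
    local_block_ids:  freshly allocated local slots for the blocks still needed;
                      leading blocks are assumed to be covered by the local
                      prefix cache, so only the tail of the remote list is read.
    Returns ("notify", []) on a full local hit, else ("read", pairs) where each
    pair is (remote_block_id, local_block_id) for one READ.
    """
    if len(local_block_ids) == 0:
        # Nothing to pull: just tell the producer it can drop its lease early.
        return "notify", []
    if len(local_block_ids) > len(remote_block_ids):
        raise ValueError("more local slots than remote blocks")
    # Align to the tail: the first blocks were already cached locally.
    tail = list(remote_block_ids)[len(remote_block_ids) - len(local_block_ids):]
    return "read", list(zip(tail, local_block_ids))


print(plan_remote_read([10, 11, 12, 13], []))      # ('notify', [])
print(plan_remote_read([10, 11, 12, 13], [7, 8]))  # ('read', [(12, 7), (13, 8)])
```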

Where the transfer hides inside the forward pass

The reason any of this stays off the critical path is overlap: the worker-side load and save are bolted onto execute_model rather than run as separate phases, so the bytes move while the GPU is busy with other work instead of during an idle stall. “Off the critical path” means the transfer’s latency hides inside time the GPU was going to spend computing anyway, rather than adding to the wall-clock the user waits. The model-runner mixin achieves this by wrapping the forward pass in a context manager that kicks off the load before the forward and harvests completions after:

# vllm/v1/worker/kv_connector_model_runner_mixin.py
# Background KV cache transfers happen here.
kv_connector.start_load_kv(get_forward_context())
try:
    yield output
finally:
    if wait_for_save and not defer_finalize:
        kv_connector.wait_for_save()

    output.finished_sending, output.finished_recving = (
        kv_connector.get_finished(scheduler_output.finished_req_ids)
    )

Source: vllm/v1/worker/kv_connector_model_runner_mixin.py

Reading it in order: start_load_kv fires the receive without waiting for it; the forward pass proceeds on the GPU while those bytes arrive in the background; then get_finished polls which transfers have actually completed and feeds those ids back to the scheduler as the finished_recving set, which is precisely the signal that springs WAITING_FOR_REMOTE_KVS requests loose in the section above. The overlap goes the other way too. On the producer side the layer-by-layer save_kv_layer / wait_for_save hooks let the KV for each transformer layer start streaming out the moment that layer is computed, rather than waiting for the whole forward pass to finish. So a request’s early-layer KV can already be on the wire to the consumer while the producer is still computing its later layers — the transfer overlaps the prefill that produced it.
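A toy timing model shows why streaming KV layer by layer hides most of the transfer. The assumptions are stated plainly: one link that carries one layer's KV at a time, constant per-layer compute and transfer times, and unitless numbers. None of this is measured.

```python
def serial_finish(num_layers: int, compute: float, transfer: float) -> float:
    """Compute every layer, then ship every layer's KV: no overlap."""
    return num_layers * compute + num_layers * transfer


def pipelined_finish(num_layers: int, compute: float, transfer: float) -> float:
    """Ship layer i's KV as soon as it is computed, one transfer at a time.

    Layer i finishes computing at (i + 1) * compute; its transfer starts once
    that layer is done and the link is free from the previous layer.
    """
    link_free = 0.0
    for i in range(num_layers):
        start = max((i + 1) * compute, link_free)
        link_free = start + transfer
    return link_free


print(serial_finish(4, 1.0, 0.5), pipelined_finish(4, 1.0, 0.5))  # 6.0 4.5
print(serial_finish(4, 1.0, 2.0), pipelined_finish(4, 1.0, 2.0))  # 12.0 9.0
```

When a layer's transfer is no slower than its compute, only the last layer's transfer is exposed. When the link is slower, it becomes the bottleneck, but the overlap still hides all the compute after the first layer.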

This overlap is not free, and the cost lands on Chapter 10’s machinery. The connector base class flags that these asynchronous per-layer operations cannot be captured in a CUDA graph, which forces requires_piecewise_for_cudagraph. A CUDA graph captures a fixed sequence of GPU launches and replays them with near-zero per-step overhead, but a transfer whose timing depends on the network is not a fixed sequence, so the graph has to be broken into pieces around it. The trade is explicit: you give up some of the launch-overhead savings from CUDA graphs in exchange for the ability to overlap transfer with compute at all.

What is still hard

Disaggregation is real and shipping, but it is not free goodput. The transfer must be amortized: the cost of moving the KV has to be small next to the prefill cost it replaces. Otherwise you have spent a network round trip to save a forward pass that was cheaper than the round trip. This is the same caveat both papers raised, now made concrete: NIXL exposes a kv_recompute_threshold (default 64 tokens) below which it simply recomputes the prompt locally on the decode instance rather than pulling it, because for a prompt that short the network round trip costs more than the prefill it would save. The two curves below make the crossover visible. Local recompute cost grows linearly with prompt length from near zero, while the pull cost starts at a fixed RDMA round-trip floor and rises only gently, so recompute wins for short prompts, the pull wins for long ones, and the threshold belongs where the curves cross. Set it too low and you pull tiny prompts that would have been cheaper to recompute; set it too high and you recompute prompts that would have been cheaper to pull. Either miscalibration costs latency on every prompt that falls between the configured threshold and the true crossover.

Illustrative: the curve shapes (linear recompute from zero, a fixed round-trip floor plus gentle slope for the pull) are the real relationship and the crossover is placed at NIXL’s 64-token default, but the absolute microsecond values are made up, not measured.
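Under the linear shapes just described, the crossover has a closed form. If local recompute costs a·n and a pull costs b + c·n with a > c, the two meet at n* = b / (a − c). The constants below are made up and chosen only so that the crossover lands at 64. should_recompute is a hypothetical rule, and its strict less-than comparison is an assumption rather than NIXL's code.

```python
def crossover_tokens(recompute_per_token: float, pull_floor: float,
                     pull_per_token: float) -> float:
    """Prompt length where a*n == b + c*n, i.e. n* = b / (a - c)."""
    if recompute_per_token <= pull_per_token:
        raise ValueError("the pull never wins unless it is cheaper per token")
    return pull_floor / (recompute_per_token - pull_per_token)


def should_recompute(num_prompt_tokens: int, threshold: int = 64) -> bool:
    # Hypothetical rule: recompute locally below the threshold, pull at or above it.
    return num_prompt_tokens < threshold


# Made-up costs in arbitrary time units: a = 1.0/token, b = 56 floor, c = 0.125/token.
print(crossover_tokens(1.0, 56.0, 0.125))            # 64.0
print(should_recompute(32), should_recompute(4096))  # True False
```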

The deeper open problems are about balance and failure. The prefill and decode pools must be sized in proportion to the actual workload, because each pool is loaded by a different quantity: the prefill pool’s load scales with total prompt tokens per second $\lambda_{\text{prompt}}$, the decode pool’s with total output tokens per second $\lambda_{\text{output}}$. A workload of short prompts and long answers needs lots of decode and little prefill; a workload of long documents summarized in a sentence needs the reverse. The right split tracks the ratio $\lambda_{\text{prompt}} / \lambda_{\text{output}}$, which is a property of the traffic, not the deployment, and it drifts hour to hour, so a split sized for yesterday’s traffic is mis-sized today. This is exactly why disaggregation pushes hard on the autoscaling story of Chapter 20: the two pools have to be scaled independently and continuously. Two more sharp edges remain. Heterogeneous TP between the pools (the producer and consumer running different tensor-parallel degrees) means a single logical block is split across a different number of GPUs on each side, so the byte layouts do not line up, which is the whole reason _read_blocks carries that block_size_ratio remapping and staging-buffer logic. And the lease-and-heartbeat dance is a distributed-systems problem in its own right: every pinned block on a producer is memory a crashed consumer can strand until a timeout fires.
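A back-of-the-envelope sizing function shows how the same fleet flips its split as $\lambda_{\text{prompt}} / \lambda_{\text{output}}$ moves. It assumes throughput scales linearly with instance count and ignores SLO headroom and burstiness. The per-instance capacities are inputs you would have to measure, and the numbers below are invented.

```python
import math


def size_pools(prompt_tok_per_s: float, output_tok_per_s: float,
               prefill_tok_per_s_per_instance: float,
               decode_tok_per_s_per_instance: float) -> tuple[int, int]:
    """Instances needed in each pool, assuming linear scaling (illustrative)."""
    prefill = math.ceil(prompt_tok_per_s / prefill_tok_per_s_per_instance)
    decode = math.ceil(output_tok_per_s / decode_tok_per_s_per_instance)
    return prefill, decode


# Same per-instance capacities, two traffic mixes (all numbers invented).
chat = size_pools(20_000, 30_000, 20_000, 4_000)       # short prompts, long answers
summarize = size_pools(100_000, 2_000, 20_000, 4_000)  # long documents, short answers
print(chat)       # (1, 8)
print(summarize)  # (5, 1)
```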

Which brings us to the obvious next question. If a request must be routed to a decode instance that can read from the prefill instance that holds its KV — and ideally to a prefill instance that already has its prompt prefix cached — then routing is no longer stateless. The router has to know each replica’s P/D role, its KV residency, and its queue depth, and steer accordingly. That is Chapter 18.

Further reading

  • DistServe: Disaggregating Prefill and Decoding for Goodput-optimized Large Language Model Serving — arXiv:2401.09670 — disaggregates the two phases onto separate pools and optimizes for goodput (requests meeting both TTFT and TPOT SLOs); the foundational goodput-and-placement argument.
  • Splitwise: Efficient Generative LLM Inference Using Phase Splitting — arXiv:2311.18677 — splits phases across machine pools at fleet scale, including different GPU SKUs per pool and KV-transfer provisioning; read for the heterogeneous-hardware and capacity view.
  • Mooncake: A KVCache-centric Disaggregated Architecture for LLM Serving — arXiv:2407.00079 — KV-centric tiered store underpinning a disaggregated cluster; both a paper and a connector backend in vLLM’s tree.

Request routing and cache-aware load balancing

We are out of the engine now. For the whole of Parts II through IV the unit of attention was one vLLM process, or one model spread across a few GPUs that still presented itself as a single scheduler. This part of the book steps up a level: there are many replicas behind a virtual endpoint, requests arrive at some front door, and something has to decide which replica each request goes to. That something is the router, and the thesis of this chapter is that the router you would build out of habit, the one you have built a dozen times for stateless services, is wrong here for a specific and interesting reason.

Your instinct, honed on the kind of traffic infrastructure this book assumes you already know, is least-connections or least-requests-in-flight, maybe power-of-two-choices to avoid the herd. That instinct is built on an assumption: that replicas are interchangeable, so the only thing that distinguishes them at routing time is how busy they are. Send the request to the least-loaded box because any box can serve it equally well. For a stateless RPC that is exactly right.

It is wrong for LLM inference because replicas are not interchangeable, and the thing that distinguishes them is invisible to a connection counter. Chapter 7 showed that a replica that has recently served your system prompt is holding its KV blocks in the idle pool, content-addressed and ready to reuse. Route there and your 2,000-token prefill collapses to a cache lookup. Route to an identically-loaded replica that has never seen the prompt and you pay the full prefill on the critical path. Two replicas with the same connection count, the same queue depth, the same GPU utilization can differ by an order of magnitude in the TTFT they will give this particular request, and the difference is entirely about what is in their caches. Least-connections cannot see it. The right replica is the one that already holds your prefix and is not about to preempt to make room, and finding it means the router has to consume signals from inside the engine that a stateless load balancer never needed.

The diagram below contrasts the two worlds. On the left, the stateless router treats every replica as a black box and reads only one number, in-flight connections, so it sends the request to whichever box looks least busy. On the right, the cache-aware router peers inside each replica: it knows which prefixes are cached and how close each replica is to its memory wall, so it can route the request to the replica that already holds the prompt even when that replica is not the emptiest.

flowchart LR
    R["incoming request: prompt = SYS + user"]

    subgraph Stateless["stateless router (least-connections)"]
        direction TB
        SR["reads: open connections only"]
        SA["replica A: 3 conns, holds SYS prefix"]
        SB["replica B: 1 conn, cold cache"]
        SR -->|"fewer conns -> pick B"| SB
    end

    subgraph CacheAware["cache-aware router"]
        direction TB
        CR["reads: cache map + load gauges"]
        CA["replica A: 3 conns, holds SYS prefix"]
        CB["replica B: 1 conn, cold cache"]
        CR -->|"A holds prefix -> pick A"| CA
    end

    R --> SR
    R --> CR
    SB -.->|"full prefill on critical path"| Slow["slow TTFT"]
    CA -.->|"prefill collapses to cache hit"| Fast["fast TTFT"]

The rest of this chapter is the story of how the router on the right earns that knowledge: which numbers the engine exports cheaply, how the engine streams its cache contents so the router can rebuild a picture of where every prefix lives, and how a routing policy weighs locality against load when both signals are stale.

What the engine is willing to tell you

Start with the cheap signals, the ones every replica already exports. Chapter 2 introduced the Prometheus series; here we read them as a router, not an operator. Two gauges carry most of the load information you want, and PrometheusStatLogger defines them plainly:

# vllm/v1/metrics/loggers.py
gauge_scheduler_waiting = self._gauge_cls(
    name="vllm:num_requests_waiting",
    documentation="Number of requests waiting to be processed.",
    multiprocess_mode="mostrecent",
    labelnames=labelnames,
)

Source: vllm/v1/metrics/loggers.py

# vllm/v1/metrics/loggers.py
gauge_kv_cache_usage = self._gauge_cls(
    name="vllm:kv_cache_usage_perc",
    documentation="KV-cache usage. 1 means 100 percent usage.",
    multiprocess_mode="mostrecent",
    labelnames=labelnames,
)

Source: vllm/v1/metrics/loggers.py

These two are the load-balancing primitives for inference. num_requests_waiting is queue depth, the thing your least-connections instinct already reaches for, except it is the engine’s real queue, the WAITING set from Chapter 5, not a count of open sockets. kv_cache_usage_perc is the one with no analog in stateless serving: it tells you how close a replica is to the memory wall, and a replica near 1.0 is a replica about to preempt (Chapter 6). Routing a fresh request to a replica at 0.98 KV usage does not just queue it; it can trigger the eviction of cached prefixes other requests were counting on, a second-order cost a connection counter is structurally blind to. A good router treats high KV usage as a strong negative signal even when the queue looks empty, because the queue being empty and the cache being full is precisely the state that precedes a preemption storm.
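A router on the scrape path needs only these two series. The sketch below parses Prometheus text exposition with a deliberately simplified regex: it ignores labels and would misparse label values containing `}`. When an endpoint exposes several series per metric, it sums queue depth and keeps the worst KV usage. pick_by_load and its 0.9 KV ceiling are hypothetical choices that encode the "high KV usage is a strong negative signal" rule as a hard filter.

```python
import re
from typing import Dict, Optional

_SAMPLE = re.compile(
    r'^(?P<name>[a-zA-Z_:][a-zA-Z0-9_:]*)(?:\{[^}]*\})?\s+(?P<value>\S+)'
)


def parse_load(exposition: str) -> Dict[str, float]:
    """Extract the two routing gauges from Prometheus text exposition."""
    waiting = 0.0
    kv_usage = 0.0
    for line in exposition.splitlines():
        if not line or line.startswith("#"):  # skip HELP/TYPE comments
            continue
        m = _SAMPLE.match(line)
        if m is None:
            continue
        name, value = m.group("name"), float(m.group("value"))
        if name == "vllm:num_requests_waiting":
            waiting += value                   # total queue across series
        elif name == "vllm:kv_cache_usage_perc":
            kv_usage = max(kv_usage, value)    # worst memory pressure
    return {"waiting": waiting, "kv_usage": kv_usage}


def pick_by_load(loads: Dict[str, Dict[str, float]],
                 kv_ceiling: float = 0.9) -> Optional[str]:
    """Least-waiting replica among those below a (hypothetical) KV ceiling."""
    eligible = {r: l for r, l in loads.items() if l["kv_usage"] < kv_ceiling}
    if not eligible:
        return None  # everyone is near the memory wall: queue or shed upstream
    return min(eligible, key=lambda r: (eligible[r]["waiting"], eligible[r]["kv_usage"]))


scrape_a = """# HELP vllm:num_requests_waiting Number of requests waiting to be processed.
# TYPE vllm:num_requests_waiting gauge
vllm:num_requests_waiting{model_name="m"} 0.0
vllm:kv_cache_usage_perc{model_name="m"} 0.98
"""
scrape_b = """vllm:num_requests_waiting{model_name="m"} 2.0
vllm:kv_cache_usage_perc{model_name="m"} 0.41
"""
loads = {"a": parse_load(scrape_a), "b": parse_load(scrape_b)}
print(pick_by_load(loads))  # b: a has the emptier queue but sits at 0.98 KV usage
```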

Scraping Prometheus on an interval is fine for slow-moving decisions, but it has a freshness problem: if the router scrapes every two seconds, its picture of a replica can be two seconds stale, which is many requests of drift on a busy fleet. A router making a choice per request wants the load attached to the response it already has in hand, so that each completed request doubles as a fresh measurement of the replica that served it. vLLM supports exactly this through an ORCA-style load header. The mechanism is a small contract between client and server: the client adds a request header that names which load-report format it wants back, and the server reads that header on the way in.

# vllm/entrypoints/openai/chat_completion/api_router.py
metrics_header_format = raw_request.headers.get(
    ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL, ""
)

Source: vllm/entrypoints/openai/chat_completion/api_router.py

On the way out, the server attaches a response header built from the same live gauges, translating each Prometheus series name into the ORCA field name the client expects:

# vllm/entrypoints/serve/utils/orca_metrics.py
prometheus_to_orca_metrics = {
    "vllm:kv_cache_usage_perc": "kv_cache_usage_perc",
    "vllm:num_requests_waiting": "num_requests_waiting",
}

Source: vllm/entrypoints/serve/utils/orca_metrics.py

This is the Open Request Cost Aggregation convention from the service-mesh world, repurposed for inference. ORCA was designed so that a backend could report its own load in-band, riding along on the responses it was already sending, instead of forcing the load balancer to poll it out of band. That is precisely the freshness fix we want here: every response carries the queue depth and KV usage of the replica that served it, measured at the instant of that response. A router that records the header from each completed request therefore holds, for every replica, a load sample taken at that replica’s most recent response, and it can run the equivalent of weighted-least-request against those samples with no separate scrape path at all. A sample is only as fresh as the replica’s traffic: a busy replica refreshes it on every response, while a quiet one may not have reported in a while. It is the same idea as the in-band server_load_metrics counter that load_aware_call maintains for simple in-flight tracking, but it carries the engine-internal numbers that actually matter for inference, queue depth and memory pressure, rather than a plain connection count.
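The router-side half of this loop is just a per-replica table of the latest sample plus its age. InBandLoadTable is a hypothetical helper. It assumes the response header has already been parsed into a dict keyed by the ORCA field names in the mapping above; the header's wire format is deliberately not modeled here.

```python
import time
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class LoadSample:
    waiting: float
    kv_usage: float
    at: float  # router-side time when the response carrying it arrived


class InBandLoadTable:
    """Per-replica load learned from responses (hypothetical router helper)."""

    def __init__(self) -> None:
        self._latest: Dict[str, LoadSample] = {}

    def record(self, replica: str, fields: Dict[str, float],
               now: Optional[float] = None) -> None:
        # fields: already-parsed ORCA values from one response header (assumed).
        now = time.monotonic() if now is None else now
        self._latest[replica] = LoadSample(
            waiting=fields["num_requests_waiting"],
            kv_usage=fields["kv_cache_usage_perc"],
            at=now,
        )

    def get(self, replica: str) -> Optional[LoadSample]:
        return self._latest.get(replica)

    def age(self, replica: str, now: Optional[float] = None) -> Optional[float]:
        """Seconds since this replica last reported, or None if it never has."""
        sample = self._latest.get(replica)
        if sample is None:
            return None
        now = time.monotonic() if now is None else now
        return now - sample.at


table = InBandLoadTable()
table.record("a", {"num_requests_waiting": 1, "kv_cache_usage_perc": 0.35}, now=100.0)
table.record("b", {"num_requests_waiting": 4, "kv_cache_usage_perc": 0.80}, now=103.5)
print(table.age("a", now=104.0), table.age("b", now=104.0))  # 4.0 0.5
```

Tracking the age makes the freshness caveat actionable. A router could discount samples past some age or fall back to a scrape for replicas that have gone quiet; any such policy is a design choice layered on top, not something the header provides.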

The diagram below traces one request through this loop. Note that the load report the router learns from describes the previous state of the replica, the state just after it finished the prior request; this is the structural reason every signal in this chapter is slightly behind reality.

sequenceDiagram
    participant C as Client / router
    participant V as vLLM replica
    C->>V: request + header (report load as ORCA)
    Note over V: serve request, sample live gauges
    V->>C: response + header (kv_cache_usage_perc, num_requests_waiting)
    Note over C: record this replica's load for the next decision
    C->>V: next request, now routed with fresher load info

The signal least-connections can’t fake: who holds the prefix

Queue depth and KV usage make the router smarter about load. They say nothing about cache content, which is the whole reason replicas stopped being interchangeable. For that you need a different kind of signal, and Chapter 7 already showed you where it comes from: the block-lifecycle events the engine emits as its cache changes.

# vllm/distributed/kv_events.py
class BlockStored(KVCacheEvent):
    block_hashes: list[ExternalBlockHash]
    parent_block_hash: ExternalBlockHash | None
    token_ids: list[int]
    block_size: int

Source: vllm/distributed/kv_events.py

# vllm/distributed/kv_events.py
class BlockRemoved(KVCacheEvent):
    block_hashes: list[ExternalBlockHash]
    medium: str | None
    group_idx: int | None = None

Source: vllm/distributed/kv_events.py

Look at what BlockStored carries: a block’s hash and its parent_block_hash. We saw in Chapter 7 that these hashes are parent-chained, meaning each block’s hash is computed from its own tokens and the hash of the block before it, so a single hash names not just one block but the entire token prefix leading up to it. That chaining is what makes the stream useful. A sequence of (hash, parent_hash) pairs is exactly the edge list of a tree: each pair says “this block hangs off that parent,” and following the parent links from any node back to the root spells out a cached prefix token by token. The engine is, in effect, serializing its prefix tree onto the wire one edge at a time as it caches new blocks.
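Parent chaining is easy to demonstrate. The scheme below (SHA-256 over the parent hash and the block's token ids, truncated for readability) is illustrative only; vLLM's actual hash function and inputs differ, and a real router must reproduce the engine's exact scheme for its hashes to line up with the event stream. As in prefix caching, only full blocks are hashed.

```python
import hashlib
from typing import List, Optional, Sequence


def chain_block_hashes(token_ids: Sequence[int], block_size: int) -> List[str]:
    """Hash each full block together with its parent's hash (illustrative scheme).

    Because block k's hash folds in block k-1's hash, equal hashes at depth k
    imply equal token prefixes up to that block. A trailing partial block is
    not hashed.
    """
    hashes: List[str] = []
    parent: Optional[str] = None
    for k in range(len(token_ids) // block_size):
        block = token_ids[k * block_size:(k + 1) * block_size]
        h = hashlib.sha256()
        h.update((parent or "root").encode())
        h.update(",".join(map(str, block)).encode())
        parent = h.hexdigest()[:16]
        hashes.append(parent)
    return hashes


system = list(range(8))  # shared "system prompt": two blocks of 4 tokens
p1 = chain_block_hashes(system + [100, 101, 102, 103], block_size=4)
p2 = chain_block_hashes(system + [200, 201, 202, 203], block_size=4)
print(p1[:2] == p2[:2], p1[2] == p2[2])  # True False
```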

A router can reassemble that tree on the other side. Subscribe to one replica’s event stream, add an edge for each BlockStored, prune for each BlockRemoved as the engine evicts blocks, and reset to empty on AllBlocksCleared, and you hold an external, approximate copy of which prefixes live in that replica’s cache, rebuilt without ever reading a byte of its GPU memory. Subscribe to every replica’s stream and keep one such tree per replica, and you have a global map of where each prefix is cached. The router’s job at request time then reduces to a longest-prefix-match query against this map: hash the incoming prompt with the same scheme the engine uses, walk that hash chain against each per-replica tree, and route to the replica whose tree matches the deepest, because that is the replica that will reuse the most of your prompt and prefill the least.
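A minimal per-replica index and longest-prefix match can be sketched as follows. ReplicaPrefixIndex and route_by_prefix are hypothetical names, and hashes are treated as opaque strings. The sketch assumes the hashes in one BlockStored event form a chain whose first block hangs off parent_block_hash. Because hashes are parent-chained, matching only needs membership, and it stops at the first missing block, since reuse requires a contiguous prefix.

```python
from typing import Dict, Iterable, Optional, Sequence


class ReplicaPrefixIndex:
    """Router-side, approximate view of one replica's cached blocks (hypothetical)."""

    def __init__(self) -> None:
        # child hash -> parent hash: the reconstructed tree's edge list.
        self._parent: Dict[str, Optional[str]] = {}

    def on_block_stored(self, block_hashes: Iterable[str],
                        parent_block_hash: Optional[str]) -> None:
        parent = parent_block_hash
        for h in block_hashes:  # assumed: one event carries a chain of new blocks
            self._parent[h] = parent
            parent = h

    def on_block_removed(self, block_hashes: Iterable[str]) -> None:
        for h in block_hashes:
            self._parent.pop(h, None)

    def on_all_blocks_cleared(self) -> None:
        self._parent.clear()

    def match_depth(self, prompt_hashes: Sequence[str]) -> int:
        depth = 0
        for h in prompt_hashes:
            if h not in self._parent:  # first gap ends the reusable prefix
                break
            depth += 1
        return depth


def route_by_prefix(prompt_hashes: Sequence[str],
                    indexes: Dict[str, ReplicaPrefixIndex]) -> Optional[str]:
    """Replica with the deepest match, or None if nobody matches at all."""
    best, best_depth = None, 0
    for replica, index in indexes.items():
        d = index.match_depth(prompt_hashes)
        if d > best_depth:
            best, best_depth = replica, d
    return best


a, b = ReplicaPrefixIndex(), ReplicaPrefixIndex()
a.on_block_stored(["H1", "H2"], None)
b.on_block_stored(["H1", "H2", "H3"], None)
b.on_block_removed(["H2"])  # eviction breaks b's chain after H1
print(route_by_prefix(["H1", "H2", "H3"], {"a": a, "b": b}))  # a
```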

The diagram below traces a single block-store event from the engine into the router’s reconstructed tree, then shows a request resolving against it.

flowchart TD
    subgraph Engine["vLLM replica: BlockPool"]
        E1["caches block for prompt prefix"]
        E2["emit BlockStored: hash=H3, parent=H2, tokens=..."]
        E1 --> E2
    end

    E2 -->|"PUB socket, async"| Pub["ZmqEventPublisher"]
    Pub --> Router

    subgraph Router["router: per-replica prefix tree"]
        T0["root"]
        T1["H1: SYS tokens"]
        T2["H2: SYS + tools"]
        T3["H3: SYS + tools + user"]
        T0 --> T1 --> T2 --> T3
    end

    Q["new request: hash prompt -> H1, H2, H3"] -->|"longest-prefix match"| Router
    Router -->|"deepest match here -> route to this replica"| Decision["chosen replica"]

The honest word in all of this is approximate, and it is worth being precise about why. The event stream is asynchronous and lossy by nature. ZmqEventPublisher ships events over a PUB socket, which is fire-and-forget: the publisher does not wait for subscribers to acknowledge, it only keeps a bounded replay buffer to help a reconnecting subscriber catch up. A subscriber that falls behind, or that reconnects after the buffer has rolled over, simply misses events, and its tree drifts away from the engine’s real cache. Even with no losses at all, there is an unavoidable lag: the router’s tree is always a few milliseconds behind the engine, so between the moment a block is actually evicted and the moment the matching BlockRemoved arrives, the router still believes the prefix is cached and can route a request to a replica that has already thrown it away.

There is a subtler trap too, which is double-counting. A replica is often not one process but several, and if every part emitted its own events the router’s tree would see the same block stored two, four, or eight times and could badly misjudge what is cached. vLLM handles this in two ways depending on where the events come from. For the core prefix cache the problem mostly does not arise: a tensor-parallel replica still has a single scheduler driving all its GPU ranks, and the BlockPool that emits these events lives in that one scheduler, so the router sees a single already-deduplicated stream per replica rather than one copy per rank. Where events genuinely do originate from many workers, as with the KV-transfer connectors of Chapters 16 and 17 that run independently per worker, vLLM deduplicates explicitly with KVEventAggregator. It counts how many workers have reported each event and emits an event only once it has been seen from all of them, so a block counts as truly cached on the replica only when every worker agrees it is:

# vllm/distributed/kv_events.py
def get_common_events(self) -> list[KVCacheEvent]:
    return [
        event
        for event, count in self._event_counter.items()
        if count == self._num_workers
    ]

Source: vllm/distributed/kv_events.py
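The counting rule can be reproduced in a few lines with collections.Counter. ToyEventAggregator is an illustrative class, not vLLM's KVEventAggregator. It adds one safeguard the quoted snippet does not show: each worker's report is de-duplicated, so a repeat from one worker cannot stand in for a missing one.

```python
from collections import Counter
from typing import Hashable, Iterable, List


class ToyEventAggregator:
    """Emit an event only after every worker has reported it (illustrative)."""

    def __init__(self, num_workers: int) -> None:
        self._num_workers = num_workers
        self._counter: Counter = Counter()

    def add_worker_events(self, events: Iterable[Hashable]) -> None:
        # set(): count each distinct event at most once per worker report.
        self._counter.update(set(events))

    def get_common_events(self) -> List[Hashable]:
        return [e for e, c in self._counter.items() if c == self._num_workers]

    def clear(self) -> None:
        self._counter.clear()


agg = ToyEventAggregator(num_workers=2)
agg.add_worker_events([("stored", "H1"), ("stored", "H2")])
agg.add_worker_events([("stored", "H1")])  # worker 2 has not stored H2 yet
print(agg.get_common_events())             # [('stored', 'H1')]
```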

Either way, the cache map is a hint, never a guarantee. A router built on it must degrade gracefully when the hint is wrong: a prefix-match miss should cost you a full prefill, not a failed request, and the routing policy has to blend the cache signal with the load signals rather than obey it blindly. Sending every request that shares a popular system prompt to the one replica that cached it first is how you turn a cache hit into a hotspot. This is the central tension of cache-aware routing, and it is worth saying plainly: you are balancing the locality benefit of reuse against the load benefit of spreading, with stale information about both.

The policy: locality versus balance

The research that named this tradeoff sharply is worth reading as the conceptual backbone of any router you build.

SGLang: Efficient Execution of Structured LM Programs (arXiv:2312.07104) is the RadixAttention paper from Chapter 7, but its serving side describes cache-aware routing across workers: keep an approximate radix tree of what each worker has cached and route to maximize prefix reuse. Read it for the framing that a router can hold a coarse model of every replica’s cache and act on it.

Preble: Efficient Distributed Prompt Scheduling for LLM Serving (arXiv:2407.00023) attacks exactly the hotspot problem above. It schedules prompts across replicas to balance prefix-cache locality against load, explicitly trading some reuse for spread when a popular prefix would otherwise overload one replica. Read it for the cost model that decides when locality is worth a load imbalance and when it is not.

The shape both arrive at is the same one your queueing intuition would predict once you accept the cache as state: routing becomes an optimization over a cost that has at least two terms. Write the depth of the matched prefix on replica $r$ as $d_r$, its queue depth as $w_r$, and its KV-cache utilization as $u_r$. The first term is a benefit, the prefill compute you save by hitting a warm prefix, which grows with $d_r$. The second term is a penalty, the queueing delay and preemption risk you incur by piling onto a replica that is already busy or near its memory wall, which grows with $w_r$ and $u_r$. A router scores each candidate replica by combining the two,

$$\text{score}(r) = \alpha \cdot d_r - \beta \cdot w_r - \gamma \cdot u_r,$$

where $\alpha$, $\beta$, and $\gamma$ weight locality against the two load penalties: the deepest match pulls toward one replica, the lightest load toward another. The router picks the replica that maximizes the combined score, $\arg\max_r \text{score}(r)$. When prefixes are long and reuse is high, the saved-prefill term is large and locality dominates, so you should chase the deepest match. When the cache map is stale or the best-matching replica is saturated, the load penalty swamps the locality benefit and balance dominates, so you fall back to least-loaded. The interesting routers are the ones that estimate both terms online from the very signals above (the BlockStored tree for the locality term, num_requests_waiting and kv_cache_usage_perc for the load term) and recompute the tradeoff per request.
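The score and the argmax translate directly into code. Candidate and choose are hypothetical names. $d_r$ is assumed to be measured in matched blocks, a unit choice that changes what a sensible $\alpha$ is, and the weights are illustrative rather than calibrated.

```python
from dataclasses import dataclass
from typing import Dict


@dataclass
class Candidate:
    match_depth: float  # d_r: matched prefix, in blocks (assumed unit)
    waiting: float      # w_r: num_requests_waiting
    kv_usage: float     # u_r: kv_cache_usage_perc, in [0, 1]


def score(c: Candidate, alpha: float, beta: float, gamma: float) -> float:
    # score(r) = alpha * d_r - beta * w_r - gamma * u_r
    return alpha * c.match_depth - beta * c.waiting - gamma * c.kv_usage


def choose(cands: Dict[str, Candidate], alpha: float, beta: float, gamma: float) -> str:
    # argmax_r score(r); ties go to the first replica in iteration order.
    return max(cands, key=lambda r: score(cands[r], alpha, beta, gamma))


fleet = {
    "warm": Candidate(match_depth=8, waiting=2, kv_usage=0.5),
    "cold": Candidate(match_depth=0, waiting=0, kv_usage=0.1),
}
print(choose(fleet, alpha=1.0, beta=1.0, gamma=10.0))  # warm (1.0 vs -1.0)
print(choose(fleet, alpha=0.1, beta=1.0, gamma=10.0))  # cold (-6.2 vs -1.0)
```

Shrinking $\alpha$ is one way a router might express low trust in a stale cache map: the same deep match then buys less, and balance wins.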

The curve below makes the crossover concrete. It scores two candidates as the warm replica fills up: a replica that already holds a deep prefix match (high $d_r$) but whose KV utilization $u_r$ is climbing, against a cold replica that holds nothing ($d_r = 0$) but sits idle. While the warm replica has headroom its locality benefit keeps its score on top, so the router chases the prefix; but as $u_r$ rises the $-\gamma \cdot u_r$ penalty drags its score down until it falls below the cold, idle replica, and at that point the router spills to balance instead. The crossing point is exactly the utilization at which preemption risk outweighs the prefill you would have saved.

Illustrative weights ($\alpha=1,\beta=1,\gamma=10$; warm replica $d_r=8, w_r=2$; cold replica $d_r=0, w_r=0, u_r=0.1$): the locality-versus-load crossover is the durable point, not the exact numbers, which depend on how the router calibrates $\alpha$, $\beta$, $\gamma$ against measured prefill savings and queueing cost.
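With the illustrative weights in the caption, the crossover can be solved by hand. The warm replica scores $8 - 2 - 10u$ and the cold one scores $0 - 0 - 10 \cdot 0.1 = -1$, so they tie where $6 - 10u = -1$, that is, at $u^* = 0.7$. The helper below (a hypothetical name) computes the same thing for any weights.

```python
def crossover_utilization(alpha: float, beta: float, gamma: float,
                          warm_depth: float, warm_waiting: float,
                          cold_score: float) -> float:
    """u at which alpha*d - beta*w - gamma*u equals the cold replica's score."""
    return (alpha * warm_depth - beta * warm_waiting - cold_score) / gamma


alpha, beta, gamma = 1.0, 1.0, 10.0
cold_score = alpha * 0 - beta * 0 - gamma * 0.1  # d=0, w=0, u=0.1 gives -1.0
u_star = crossover_utilization(alpha, beta, gamma, 8, 2, cold_score)
print(round(u_star, 3))  # 0.7
```

Below $u^* = 0.7$ the router chases the warm prefix; above it, the router spills to the cold replica.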

The flowchart below sketches that per-request decision. It is deliberately a soft decision, not a hard switch: a real router blends the two terms into a single score rather than taking one branch or the other, but the branches make the dominant force at each extreme legible.

flowchart TD
    Start["request arrives"] --> Hash["hash prompt, query per-replica prefix trees"]
    Hash --> Match{"any deep prefix match?"}
    Match -->|"no match"| LeastLoad["route by load only: lowest num_requests_waiting and kv_cache_usage_perc"]
    Match -->|"deep match found"| Busy{"is the matched replica near its limits? high queue or KV usage > threshold"}
    Busy -->|"replica has headroom"| Local["route to matched replica: reuse the warm prefix"]
    Busy -->|"replica saturated"| Spread["spill: route to a less-loaded replica, accept a colder cache"]
    LeastLoad --> Done["chosen replica"]
    Local --> Done
    Spread --> Done

P/D changes the question

Chapter 17 split prefill and decode onto separate pools. That fractures the routing problem into two, because the two pools care about different things. A prefill replica is compute-bound and short-lived per request; what matters there is cache locality (does it already hold the prefix?) and compute headroom. A decode replica is memory-bound and holds the request for its entire generation; what matters there is KV headroom and how many sequences it is already decoding. The same kv_cache_usage_perc gauge means something different on each: on a prefill node it is transient pressure, on a decode node it is the binding constraint for the request’s whole lifetime.

So a P/D-aware router is really two routers chained by a handoff. First it picks a prefill replica by prefix locality, the deepest-match logic from the section above, since that is where reuse pays off. It lets that replica compute the KV for the prompt. Then it picks a decode replica that has room to receive that KV and hold the request for the whole generation, scored by KV headroom rather than locality. Finally the KV cache itself is shipped from the prefill replica to the decode replica over the connector API from Chapters 16 and 17, and decode proceeds there. The router has to know each replica’s role, and crucially it has to make the two choices jointly: routing for prefill locality is pointless if the only available decode target has no KV headroom, because then the carefully-reused prefix just sits in the prefill replica’s hands, stalled, waiting for a decode slot to open. The locality win on the prefill side evaporates if the decode side cannot accept the result. This is the routing layer’s version of the same disaggregation bargain from Chapter 17, now expressed as a placement problem across two pools instead of a scheduling problem inside one engine.

The diagram below traces a request through the handoff and shows where each routing decision sits.

sequenceDiagram
    participant Cl as Client
    participant Ro as P/D router
    participant Pf as Prefill replica (by locality)
    participant De as Decode replica (by KV headroom)
    Cl->>Ro: request (prompt)
    Note over Ro: decision 1: pick prefill replica with deepest prefix match
    Ro->>Pf: route for prefill
    Note over Pf: compute full KV for the prompt
    Note over Ro: decision 2: pick decode replica with room
    Pf->>De: transfer KV over connector API
    De->>Cl: stream decoded tokens
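The two-step placement described above can be sketched as a joint decision that refuses to route at all when no decode replica can take the result. All names are hypothetical. free_kv_blocks stands in for whatever headroom signal a real router derives from kv_cache_usage_perc, and blocks_needed is assumed to cover the prompt plus an estimate of the generation.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class PrefillReplica:
    name: str
    match_depth: int  # blocks of this prompt already cached there
    waiting: int


@dataclass
class DecodeReplica:
    name: str
    free_kv_blocks: int  # hypothetical headroom signal, in blocks


def place(prefills: List[PrefillReplica], decodes: List[DecodeReplica],
          blocks_needed: int) -> Optional[Tuple[str, str]]:
    """Pick (prefill, decode) jointly; return None rather than strand a prefill."""
    # Check decision 2 first: without a decode target that fits, a prefill
    # locality win would only park finished KV waiting for a slot.
    roomy = [d for d in decodes if d.free_kv_blocks >= blocks_needed]
    if not roomy or not prefills:
        return None
    # Decision 1: deepest prefix match, ties broken by shorter queue.
    p = max(prefills, key=lambda r: (r.match_depth, -r.waiting))
    # Decision 2: most KV headroom among decode replicas that fit.
    d = max(roomy, key=lambda r: r.free_kv_blocks)
    return p.name, d.name


P = [PrefillReplica("p0", match_depth=6, waiting=3), PrefillReplica("p1", 0, 0)]
D = [DecodeReplica("d0", free_kv_blocks=40), DecodeReplica("d1", free_kv_blocks=400)]
print(place(P, D, blocks_needed=128))      # ('p0', 'd1')
print(place(P, D[:1], blocks_needed=128))  # None
```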

Who actually consumes these signals

vLLM emits; it does not route. The routing lives in a small ecosystem of external components, and it helps to see them as consumers of exactly the signals above. The gateway-api-inference-extension, the Kubernetes-native effort, runs an endpoint picker that scrapes num_requests_waiting and kv_cache_usage_perc and increasingly consumes the KV-event stream for prefix awareness; it is the path most production fleets will reach for because it slots into existing ingress. sglang-router is the reference implementation of the radix-tree approach from the SGLang paper, maintaining the approximate per-worker cache trees and doing longest-prefix-match routing directly. Dynamo, NVIDIA’s serving framework, builds a global KV-aware router on the same block-event idea, layered with its own KV manager. llm-d assembles the Kubernetes pieces, the inference gateway plus vLLM’s events and metrics, into an opinionated cache-aware deployment. They differ in substrate and ambition, but they all reduce to the same loop: subscribe to block events to know where prefixes live, scrape or read-from-header the load gauges to know who is busy, and route each request to maximize reuse without creating a hotspot. The engine’s contribution is to make all of that observable; it was built to emit its cache state and its load, not to decide.

What is still unsolved

It is worth being honest about the rough edges, because this is the youngest layer in the book. The cache map is fundamentally stale and lossy, and nobody has a clean answer for how far to trust it; the safe default of blending it with load signals leaves real reuse on the table when the map happens to be accurate. Hotspot avoidance is a live tradeoff with no universal cost model, which is exactly why Preble is a paper and not a config flag. The signals themselves are coarse: a single kv_cache_usage_perc scalar collapses a whole prefix tree into one number, so the router learns that a replica is full but not which prefixes it would evict to admit you, which is precisely what you need to predict the second-order cost of routing there. Multi-tenancy adds a routing dimension this chapter has only gestured at. A router that ignores LoRA adapters will scatter requests for one adapter across replicas that each then pay to load it. That is the problem Chapter 19 takes up next: there, the vllm:lora_requests_info gauge becomes another routing signal, and the per-adapter prefix-cache fork from Chapter 7 means the cache map itself has to be adapter-aware. And the whole edifice assumes prefixes repeat, an assumption the engine cannot enforce and the router can only exploit when the traffic above it cooperates.

The throughline of this chapter is the one to carry forward: state at the replica level turns load balancing from a stateless connection-counting problem into a placement problem over engine-internal signals, and the signals were there waiting because earlier chapters built the engine to expose them. Next we keep climbing the fleet: serving hundreds of fine-tuned variants from one base model, where the adapter becomes both a batching problem inside the engine and, as we just saw, another thing the router upstream has to route on.

Further reading

  • SGLang: Efficient Execution of Structured LM Programs — arXiv:2312.07104 — cache-aware routing across workers via approximate per-worker radix trees, the conceptual basis for prefix-locality routing.
  • Preble: Efficient Distributed Prompt Scheduling for LLM Serving — arXiv:2407.00023 — a cost model for trading prefix-cache locality against load when a popular prefix would overload one replica.

Multi-LoRA serving

Most of this book has assumed one model per replica. The token-budget scheduler from Chapter 5 batches requests against a single set of weights; the prefix cache from Chapter 7 reuses KV blocks across requests that share context; the tensor-parallel sharding from Chapter 15 splits one weight matrix across GPUs. None of that asked what happens when the requests in a batch want different weights.

In practice they often do. A platform that fine-tunes a base model for hundreds of customers, or hundreds of tasks, ends up with hundreds of variants that differ from the base by a tiny amount. Low-rank adaptation (LoRA) is what makes this cheap: instead of a full fine-tune, you learn a pair of small matrices $A$ and $B$ per target layer such that the effective weight is $W + \frac{\alpha}{r} \cdot B A$, where $r$, the rank, is typically 8 to 64. To unpack that formula: $W$ is the original weight matrix, say of shape $d_{\text{out}} \times d_{\text{in}}$; $A$ is $r \times d_{\text{in}}$ and $B$ is $d_{\text{out}} \times r$, so their product $BA$ has the same shape as $W$ but is forced to be low-rank, because it factors through the narrow $r$-dimensional waist. The scalar $\frac{\alpha}{r}$ rescales the delta so that changing the rank does not change its typical magnitude. The base $W$ is frozen and shared; the two thin matrices that make up the adapter total a few megabytes. The modeling question is settled. The serving question is not, and it is the one this chapter is about: how do you batch a request for adapter 7 next to a request for adapter 113 next to a request for the base model, run them through the same fused kernels, and still apply each request’s own low-rank delta?
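To make “a few megabytes” concrete, here is the arithmetic under a hypothetical shape: 32 layers, hidden size 4096, square projections, rank 8, adapters on q_proj and v_proj only, bf16 weights (2 bytes per parameter). The helper names are illustrative.

```python
def lora_pair_bytes(d_in, d_out, rank, bytes_per_param=2):
    """Bytes for one target layer's adapter: A is rank x d_in, B is d_out x rank."""
    return (rank * d_in + d_out * rank) * bytes_per_param


def dense_bytes(d_in, d_out, bytes_per_param=2):
    """Bytes for the frozen base matrix W (d_out x d_in) the adapter modifies."""
    return d_in * d_out * bytes_per_param


# Hypothetical shape: 32 layers x {q_proj, v_proj}, hidden 4096, rank 8, bf16.
layers, hidden, rank = 32, 4096, 8
targets = layers * 2
adapter_total = targets * lora_pair_bytes(hidden, hidden, rank)
print(adapter_total / 2**20, "MiB")                                            # 8.0 MiB
print(dense_bytes(hidden, hidden) // lora_pair_bytes(hidden, hidden, rank))  # 256
```

Per target layer the adapter is smaller than the matrix it modifies by a factor of $d_{\text{in}} d_{\text{out}} / (r(d_{\text{in}} + d_{\text{out}}))$, which is 256 at this shape.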

The naive answer is to refuse the problem: give each adapter its own replica. That throws away the entire premise. The base weights dominate memory, and a replica serving one rarely used adapter wastes a whole GPU’s worth of frozen weights to host a few megabytes of delta. The whole point is statistical multiplexing across adapters on shared base weights, which means the batch will be heterogeneous, and the engine has to make a heterogeneous batch run at homogeneous-batch efficiency. This is a memory-and-batching problem, not a modeling one.

The wrong way and the Punica way

Start with what you must avoid. You could merge each adapter into the base weights ($W' = W + \frac{\alpha}{r} \cdot B A$) and run a clean single-model forward pass. That is fast per request and catastrophic across requests: merging is per-adapter, so a batch with $k$ distinct adapters needs $k$ distinct weight matrices, and you are back to one model per batch, with the added cost of re-merging on every switch. Conversely, you could keep adapters separate and loop over the batch, running a small GEMM per request. That is correct, but a decode step is already memory-bound (Chapter 3), and launching a tiny kernel per sequence drowns the GPU in launch overhead, exactly the per-step cost Chapter 10 worked to kill.

The way out is a batched LoRA kernel that handles many adapters in one launch by carrying a per-token index telling it which adapter’s weights to use. The trick is to stop thinking of the batch as a set of requests, each with its own weights, and start thinking of it as a flat list of token rows, each carrying a small integer that names its adapter. The kernel then does one big matrix multiply over all the rows at once, but for each row it reads its index and gathers the matching adapter’s $A$ and $B$ from a shared pool. Many adapters, one launch, no per-request loop. This is the Punica contribution, and vLLM’s implementation cites it directly:
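The semantics of one multi-adapter launch can be written as a slow pure-Python reference, assuming the shapes from the LoRA formula ($A$ is $r \times d_{\text{in}}$, $B$ is $d_{\text{out}} \times r$) and the convention, used later in this chapter, that index -1 marks a base-model row. batched_lora is a hypothetical name; a GPU kernel computes every row in parallel within one launch, and the loop here only makes the per-row gather explicit.

```python
def batched_lora(x, y, indices, A_stack, B_stack, scale):
    """Reference semantics of one multi-adapter launch (pure Python, slow).

    x[t]: token row of length d_in; y[t]: base output row of length d_out
    indices[t]: slot for row t, or -1 for a base-model row
    A_stack[s]: r x d_in matrix; B_stack[s]: d_out x r matrix
    """
    out = [list(row) for row in y]
    for t, slot in enumerate(indices):
        if slot == -1:
            continue  # base row: leave the base output untouched
        A, B = A_stack[slot], B_stack[slot]  # gather this row's adapter by slot
        # shrink: project d_in down to r
        buf = [sum(a * xi for a, xi in zip(a_row, x[t])) for a_row in A]
        # expand: project r back up to d_out and accumulate onto the base output
        for i, b_row in enumerate(B):
            out[t][i] += scale * sum(b * bj for b, bj in zip(b_row, buf))
    return out
```

Rows that share a slot read the same weights, and a base row costs nothing beyond the index check, which is why adapters and base requests can share one launch.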

# vllm/lora/punica_wrapper/punica_base.py
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""

Source: vllm/lora/punica_wrapper/punica_base.py

The base-layer math, applied to the output of a wrapped linear layer, is two GEMMs — a shrink down to rank $r$ and an expand back to the output dimension:

# vllm/lora/punica_wrapper/punica_gpu.py
self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
self.add_expand(y, buffer, lora_b_stacked, output_slices,
                add_inputs=add_inputs, **kwargs)

Source: vllm/lora/punica_wrapper/punica_gpu.py

Read those two calls as the two halves of the delta $\frac{\alpha}{r} \cdot B A$ applied to the input x, treating each token as a row vector of width $d_{\text{in}}$. The shrink computes $x A^{\top}$ (multiplied by scale) and lands it in buffer, the rank-$r$ intermediate: it projects each token’s hidden vector down from the model width $d_{\text{in}}$ to the narrow rank $r$. That intermediate is held in float32 for numerical headroom. The expand then computes $\text{buffer} \cdot B^{\top}$ and accumulates it into the base output y, projecting back up from $r$ to the output width $d_{\text{out}}$. Splitting the delta into shrink-then-expand is what keeps it cheap: instead of one $d_{\text{out}} \times d_{\text{in}}$ multiply per token you do two skinny multiplies through the $r$-wide waist, about $r(d_{\text{in}} + d_{\text{out}})$ multiply-adds instead of $d_{\text{in}} d_{\text{out}}$, which is the whole reason low rank is affordable. Note the sequencing: the base linear has already produced y for every token; LoRA is a residual added on top, never a replacement, which is exactly why a base-model request and an adapted request can ride in the same batch.
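A quick check of both claims, written with column vectors so the delta is $B(Ax)$: shrink-then-expand gives the same result as applying the dense product $BA$, and it costs far fewer multiply-adds. The helpers are hypothetical and count multiply-adds per token.

```python
def matvec(M, v):
    """M (rows x cols) times column vector v (length cols)."""
    return [sum(m * vi for m, vi in zip(row, v)) for row in M]


def matmul(P, Q):
    """P (n x k) times Q (k x m)."""
    return [[sum(P[i][k] * Q[k][j] for k in range(len(Q))) for j in range(len(Q[0]))]
            for i in range(len(P))]


def dense_delta_macs(d_in, d_out):
    return d_in * d_out            # applying the full d_out x d_in product BA


def shrink_expand_macs(d_in, d_out, r):
    return r * d_in + d_out * r    # shrink (r x d_in), then expand (d_out x r)


# Same result either way: B (A x) == (B A) x, by associativity.
A = [[1, 2, 3]]   # r = 1, d_in = 3
B = [[1], [2]]    # d_out = 2
x = [1, 1, 1]
assert matvec(B, matvec(A, x)) == matvec(matmul(B, A), x) == [6, 12]
print(dense_delta_macs(4096, 4096) // shrink_expand_macs(4096, 4096, 16))  # 128
```

At width 4096 and rank 16 the factored form does 1/128 of the work of the dense delta, and the ratio improves as the rank shrinks.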

What makes it multi-tenant is that lora_a_stacked and lora_b_stacked are not single matrices. They are stacks indexed by an adapter slot (a small integer naming one shelf in the GPU’s adapter pool), and the kernel reads a per-token mapping to pick the right slice for each row of x. The diagram below traces one batched launch: four token rows belonging to two different adapters (slots 0 and 2) and the base model, all flowing through a single shrink/expand pair that selects each row’s A/B by its slot index.

flowchart TD
    subgraph BATCH["input rows x (one per token)"]
        T0["row 0  slot 0"]
        T1["row 1  slot 0"]
        T2["row 2  slot 2"]
        T3["row 3  base, index -1"]
    end
    subgraph POOL["stacked adapter buffers (indexed by slot)"]
        S0["slot 0: A0 / B0"]
        S1["slot 1: A1 / B1"]
        S2["slot 2: A2 / B2"]
    end
    BATCH --> SHRINK["shrink: buffer = x @ A[slot]  (down to rank r)"]
    POOL -->|"A[slot]"| SHRINK
    SHRINK --> EXPAND["expand: y += buffer @ B[slot]  (back to d_out)"]
    POOL -->|"B[slot]"| EXPAND
    BASE["base linear y = x @ W (runs for all rows)"] --> EXPAND
    EXPAND --> OUT["output y (base + per-row LoRA delta; index -1 skipped)"]

One launch, heterogeneous adapters, no merging, no per-request loop.

Slots, not adapters

The kernel indexes by slot, and slots are a fixed, small resource — this is the memory side of the problem. LoRAConfig names the two numbers that govern everything downstream:

# vllm/config/lora.py
max_lora_rank: MaxLoRARanks = 16
"""Max LoRA rank."""
max_loras: int = Field(default=1, ge=1)
"""Max number of LoRAs in a single batch."""

Source: vllm/config/lora.py

max_loras is the number of GPU-resident adapter slots — how many distinct adapters can appear in one batched kernel launch. max_cpu_loras (defaulting to max_loras) is how many adapters are kept staged in CPU memory. The key design choice is that the GPU-resident weight buffers are allocated once, up front, at max_loras * max_lora_rank and padded, rather than allocated per adapter on demand. That has two consequences worth making explicit. First, a batch can never ask for more slots than were provisioned, so the kernel’s indexing is always in bounds. Second, every adapter occupies a slot sized for max_lora_rank regardless of its actual rank, so a slot is a fixed-size shelf you swap adapters into and out of, not a custom-fit allocation. Fixed shelves are what avoid fragmentation: with per-adapter allocations of varying size, a churning population of adapters would leave the GPU’s adapter memory pocked with holes too small to reuse. This is the S-LoRA insight in vLLM’s shape: a unified pool of uniform adapter slots with a fixed budget, rather than per-adapter allocations that fragment as the adapter population churns.
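The cost of provisioning slots up front is simple arithmetic. This sketch assumes the A and B stacks per target layer hold exactly max_loras × max_lora_rank × (d_in + d_out) parameters, and it ignores any other buffers an implementation may allocate. The helper names and the example shape are hypothetical.

```python
def stacked_buffer_bytes(max_loras, max_lora_rank, d_in, d_out, n_targets,
                         bytes_per_param=2):
    """Up-front bytes for stacked A/B buffers, every slot padded to max_lora_rank.

    Per target layer: A stack is max_loras x max_lora_rank x d_in,
                      B stack is max_loras x d_out x max_lora_rank.
    """
    per_target = max_loras * max_lora_rank * (d_in + d_out) * bytes_per_param
    return per_target * n_targets


def slot_fill(adapter_rank, max_lora_rank):
    """Fraction of a padded slot that an adapter of a given rank actually uses."""
    return adapter_rank / max_lora_rank


# Hypothetical: 32 layers x 4 square 4096-wide projections, 4 slots at rank 16, bf16.
print(stacked_buffer_bytes(4, 16, 4096, 4096, 32 * 4) / 2**20, "MiB")  # 128.0 MiB
print(slot_fill(8, 64))                                                # 0.125
```

Both knobs scale the reservation linearly, and slot_fill shows the padding cost of a rank-8 adapter sitting in a slot sized for rank 64.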

Activation is the act of binding a registered adapter to a free slot index and copying its weights into the stacked GPU buffers:

# vllm/lora/model_manager.py
first_free_slot = next(
    ((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id)
     if lora_id is None), None)
if first_free_slot is None:
    raise ValueError("No free lora slots")
index, _ = first_free_slot
self._active_adapters[lora_id] = None
lora_model = self._registered_adapters[lora_id]
self.lora_index_to_id[index] = lora_model.id

Source: vllm/lora/model_manager.py

lora_index_to_id is the slot table: a list of length max_loras mapping each physical slot to the adapter currently living in it, or None for free. It is the single source of truth that ties the two worlds together. On one side, activation finds a hole (a slot holding None), records the binding, and then walks every wrapped module to copy that adapter’s A/B matrices into position index of the stacked GPU buffers via module.set_lora(index, ...). Deactivation just nulls the slot. On the other side, the batched kernel from the previous section reads this same table’s inverse, the per-token slot indices, to address adapters by a small integer. That indirection is the whole point: the kernel never needs a weight pointer or an adapter id, only “which shelf,” and the slot table is what assigns and reclaims shelves.
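A toy version of the slot table, assuming only what the excerpt shows: a fixed-length list, first-free-slot allocation, an error when full, and deactivation by nulling. SlotTable and its methods are hypothetical, and the weight copy into each wrapped module is left as a comment.

```python
class SlotTable:
    """Hypothetical mini slot table mirroring lora_index_to_id (not vLLM's class)."""

    def __init__(self, max_loras):
        self.lora_index_to_id = [None] * max_loras  # slot -> adapter id, None = free

    def activate(self, lora_id):
        if lora_id in self.lora_index_to_id:
            return self.lora_index_to_id.index(lora_id)  # already resident
        first_free = next((i for i, held in enumerate(self.lora_index_to_id)
                           if held is None), None)
        if first_free is None:
            raise ValueError("No free lora slots")
        self.lora_index_to_id[first_free] = lora_id
        # A real engine would now copy this adapter's A/B into position
        # first_free of every wrapped module's stacked buffers.
        return first_free

    def deactivate(self, lora_id):
        self.lora_index_to_id[self.lora_index_to_id.index(lora_id)] = None

    def slot_of(self, lora_id):
        """Inverse lookup used when building per-token indices; -1 if not resident."""
        if lora_id in self.lora_index_to_id:
            return self.lora_index_to_id.index(lora_id)
        return -1
```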

When more adapters are wanted than there are slots, the same LRU logic the prefix cache used in Chapter 7 reappears, one level up. The active set is an LRU cache keyed by adapter id; activating a new adapter when the GPU slots are full evicts the least-recently-used one:

# vllm/lora/model_manager.py
if (lora_id not in self._active_adapters
        and len(self._active_adapters) >= self.lora_slots):
    self._active_adapters.remove_oldest()
result = super().activate_adapter(lora_id)
self._active_adapters.touch(lora_id)

Source: vllm/lora/model_manager.py

Two such caches are stacked, forming a memory hierarchy for adapters that mirrors the disk/CPU/GPU hierarchy you already know for KV blocks: a CPU-side LRU of registered adapters (capacity max_cpu_loras) and a GPU-side LRU of activated adapters (capacity max_loras). An adapter starts cold on disk; on registration it is loaded into CPU memory; on first use in a batch it is staged from CPU into a free GPU slot (or one freed by evicting the least-recently-used resident adapter); and as long as it keeps getting used, recency keeps it resident. The loader in worker_manager.py states the policy in its class docstring: “Every request, the requested LoRAs will be loaded (unless they are already loaded), and every other LoRA will be unloaded.” The state diagram below traces one adapter through these tiers and the events that move it between them.

stateDiagram-v2
    [*] --> OnDisk: adapter checkpoint exists
    OnDisk --> InCPU: register (load to host RAM)
    InCPU --> OnDisk: CPU LRU evict (max_cpu_loras full)
    InCPU --> InGPU: activate (copy A/B into a free slot)
    InGPU --> InCPU: GPU LRU evict (max_loras full, slot needed)
    InGPU --> InGPU: touch (used this step, stays warm)
    note right of InGPU
        occupies one of max_loras slots
        addressable by the batched kernel
    end note

Adapter residency is a caching problem with the same eviction pressure, the same locality assumptions, and the same cliff when the working set exceeds capacity that you already know from KV blocks. The two capacities matter independently: max_cpu_loras bounds how much disk reloading you suffer, and max_loras bounds how much GPU weight-copying you suffer, and a workload can be bottlenecked on either tier.
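The two stacked tiers can be modeled with two OrderedDict LRUs. This is a hypothetical sketch, not vLLM’s cache classes: it counts disk loads (CPU-tier misses) and GPU copies (GPU-tier misses), assumes max_cpu_loras >= max_loras, and does not model what happens to a GPU-resident adapter when the CPU tier evicts it.

```python
from collections import OrderedDict


class AdapterCache:
    """Hypothetical two-tier LRU for adapters (CPU tier above GPU slots)."""

    def __init__(self, max_cpu_loras, max_loras):
        assert max_cpu_loras >= max_loras, "sketch assumes the CPU tier is at least as large"
        self.cpu = OrderedDict()  # registered adapters, least recent first
        self.gpu = OrderedDict()  # activated adapters, least recent first
        self.max_cpu, self.max_gpu = max_cpu_loras, max_loras
        self.disk_loads = 0  # CPU-tier misses: weights read from disk
        self.gpu_copies = 0  # GPU-tier misses: weights copied into a slot

    @staticmethod
    def _touch(tier, capacity, key):
        """Insert or refresh key; evict the least-recently-used entry if over capacity."""
        tier[key] = True
        tier.move_to_end(key)
        if len(tier) > capacity:
            tier.popitem(last=False)

    def use(self, lora_id):
        """Record that this step's batch needs lora_id resident in a GPU slot."""
        if lora_id not in self.cpu:
            self.disk_loads += 1
        self._touch(self.cpu, self.max_cpu, lora_id)
        if lora_id not in self.gpu:
            self.gpu_copies += 1
        self._touch(self.gpu, self.max_gpu, lora_id)
```

Separate counters make the two bottlenecks visible independently: a workload can be cheap on disk loads yet still pay heavily in GPU copies, or the reverse.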

Building the per-token map every step

The batched kernel needs, for each token in the batch, the slot index of that token’s adapter (or -1 for the base model). But the scheduler thinks in requests, and the kernel thinks in rows, so something has to translate one into the other on every step. That translation is what this section is about: turning “request 42 wants adapter 7, request 43 wants the base model” into a flat per-token vector of slot indices the GPU can read. The v1 model-runner mixin assembles it from the input batch:

# vllm/v1/worker/lora_model_runner_mixin.py
prompt_lora_mapping, token_lora_mapping, lora_requests = (
    input_batch.make_lora_inputs(num_scheduled_tokens, num_sampled_tokens)
)
return self._set_active_loras(
    prompt_lora_mapping, token_lora_mapping, lora_requests, mapping_type)

Source: vllm/v1/worker/lora_model_runner_mixin.py

Two mappings, because two parts of the model need adapter indices at different granularities. token_lora_mapping has one entry per scheduled token, for the linear-layer LoRAs that run over every position (every token passes through the attention and MLP projections). prompt_lora_mapping has one entry per sampled token, for the logits/sampler LoRA that only runs at the last position of each sequence, where the next-token distribution is produced. The two granularities exist because the model applies LoRA in two different places that see different numbers of rows: thousands of token positions in the layers, but only one sampled position per request at the head.
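Here is a hypothetical re-creation of the two mappings from per-request data, assuming each request carries one adapter id (0 stands in for “no adapter” in this sketch) plus its scheduled and sampled token counts for the step. The function name is illustrative; it returns the pair in the same order as make_lora_inputs in the excerpt.

```python
def make_lora_mappings(req_lora_ids, num_scheduled_tokens, num_sampled_tokens):
    """Expand per-request adapter ids into per-row mappings (hypothetical sketch)."""
    token_mapping = []
    prompt_mapping = []
    for lora_id, n_sched, n_samp in zip(req_lora_ids, num_scheduled_tokens,
                                        num_sampled_tokens):
        token_mapping.extend([lora_id] * n_sched)   # one entry per scheduled token
        prompt_mapping.extend([lora_id] * n_samp)   # one entry per sampled token
    return tuple(prompt_mapping), tuple(token_mapping)


# A 3-token prefill chunk for adapter 7, one base decode, a 2-token chunk for 113.
prompt, token = make_lora_mappings([7, 0, 113], [3, 1, 2], [1, 1, 1])
print(prompt)  # (7, 0, 113)
print(token)   # (7, 7, 7, 0, 113, 113)
```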

Note this is recomputed every step against the current batch composition. Continuous batching (Chapter 5) means the set of requests, and therefore the set of active adapters, changes step to step; the LoRA index tensors are part of the per-step metadata the scheduler hands the worker, no different in spirit from the block tables. The sequence diagram below traces a single decode step, from the scheduler’s request-level decisions down to the kernel reading slot indices.

sequenceDiagram
    participant Sched as Scheduler
    participant Batch as InputBatch
    participant Mixin as ModelRunner mixin
    participant Punica as PunicaWrapper on GPU
    participant Kernel as Triton shrink/expand kernel
    Sched->>Batch: requests admitted this step (each with adapter id)
    Batch->>Mixin: make_lora_inputs returns token + prompt mappings
    Mixin->>Punica: set_active_loras (host-side mapping)
    Punica->>Punica: update_metadata maps ids to slot indices via lora_index_to_id
    Punica->>Kernel: token_lora_indices, sampler_indices (-1 = base/unslotted)
    Kernel->>Kernel: per row: skip if -1, else shrink/expand from that slot

On the GPU, update_metadata turns the host-side mapping into the index tensors the Triton kernels consume:

# vllm/lora/punica_wrapper/punica_gpu.py
self._update_base_metadata(mapping, lora_index_to_id, max_loras, vocab_size)
# Prepare cuda kernel metadata tensors
self.token_mapping_meta.prepare_tensors(self.token_lora_indices)
self.prompt_mapping_meta.prepare_tensors(self.sampler_indices)

Source: vllm/lora/punica_wrapper/punica_gpu.py

A token whose adapter is not currently slotted, or a base-model request, carries index -1, and the kernel’s convention is that -1 means “skip”: for those rows the shrink and expand simply do not write, leaving the base output untouched. This single sentinel value is what makes the whole scheme uniform. There is no separate code path for base-model requests and no special-casing in the scheduler; the base forward pass produces y for everyone, and the LoRA kernel adds a delta only where the per-token index points at a real slot. Base and adapted requests differ by one integer in a vector, nothing more.
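The id-to-slot translation behind that sentinel is an inverse lookup on the slot table. In this hypothetical helper, any id not resident in a slot, including whatever marker a sketch uses for the base model, maps to -1.

```python
def to_slot_indices(token_lora_ids, lora_index_to_id):
    """Translate per-token adapter ids into slot indices; -1 means "skip this row"."""
    slot_of = {lora_id: slot for slot, lora_id in enumerate(lora_index_to_id)
               if lora_id is not None}
    return [slot_of.get(lora_id, -1) for lora_id in token_lora_ids]


# Slot table: adapter 113 in slot 0, slot 1 free, adapter 7 in slot 2.
print(to_slot_indices([7, 7, 0, 113, 42], [113, None, 7]))  # [2, 2, -1, 0, -1]
```

Adapter 42 is requested but not slotted, so it gets -1 exactly like a base-model row; in the engine, the scheduler’s slot check is what keeps such rows out of the batch in the first place.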

Tensor parallelism, packing, and where the sharding goes

Chapter 15 sharded the base weights across a TP (tensor-parallel) group, splitting each weight matrix across GPUs so each rank holds only a slice. The adapters have to follow the same split, otherwise the LoRA delta would not line up with the base output it is being added to, and how they shard differs by layer type. This is why vLLM wraps each parallel-linear flavor in its own LoRA layer. Take ColumnParallelLinear, which shards its output dimension: each rank computes one chunk of the output columns. The base output y is therefore already partitioned across ranks, so the LoRA delta must be partitioned the same way. Since $B$ is the matrix that produces the output, $B$ is split along its $d_{\text{out}}$ dimension to match, while $A$ (which only reaches the rank-$r$ waist) is replicated on every rank. The consequence is that the shrink is computed redundantly on each rank, and only the expand is sharded. A RowParallelLinear, which shards its input dimension, is the mirror image: $A$ is split along $d_{\text{in}}$ and $B$ is replicated, and the per-rank partial results are summed, just as the base layer’s partial outputs are. The wrapper layers (column_parallel_linear.py, row_parallel_linear.py) encode exactly this by overriding slice_lora_a/slice_lora_b to cut each adapter’s weights to the local shard at activation time.
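Both sharding schemes can be checked against the unsharded delta $B(Ax)$ in pure Python. The sketch simulates the ranks in a loop, assumes dimensions divide evenly by the TP size, and uses list concatenation in place of the all-gather and an elementwise sum in place of the all-reduce. The function names are hypothetical, and summing before the expand rather than after is a choice this sketch makes; by linearity both give the same result.

```python
def matvec(M, v):
    """M (rows x cols) times column vector v."""
    return [sum(m * vi for m, vi in zip(row, v)) for row in M]


def column_parallel_lora(x, A, B, tp):
    """B split along d_out; A replicated, so every rank repeats the same shrink."""
    d_out = len(B)
    assert d_out % tp == 0, "sketch assumes d_out divides evenly"
    shard = d_out // tp
    outputs = []
    for rank in range(tp):
        buf = matvec(A, x)                              # replicated shrink
        B_local = B[rank * shard:(rank + 1) * shard]    # this rank's output rows
        outputs.append(matvec(B_local, buf))            # sharded expand
    return [v for part in outputs for v in part]        # stands in for all-gather


def row_parallel_lora(x, A, B, tp):
    """A split along d_in; B replicated; rank-r partials summed before the expand."""
    d_in = len(x)
    assert d_in % tp == 0, "sketch assumes d_in divides evenly"
    shard = d_in // tp
    partials = []
    for rank in range(tp):
        cols = slice(rank * shard, (rank + 1) * shard)
        A_local = [row[cols] for row in A]              # this rank's input columns
        partials.append(matvec(A_local, x[cols]))       # partial shrink
    buf = [sum(vals) for vals in zip(*partials)]        # stands in for all-reduce
    return matvec(B, buf)                               # replicated expand
```

The column-parallel version recomputes matvec(A, x) on every rank, which is the redundant work that fully_sharded_loras targets.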

That replicated half is wasted work: every rank redoes the same small GEMM. The optional fully_sharded_loras flag pushes the split further, sharding the otherwise-replicated half too at the cost of an extra communication step to reassemble it. It pays off “at high sequence length, max rank or tensor parallel size” per the config docstring, because that is exactly when the redundant compute grows large enough to outweigh the added communication. It is the same communication-versus-redundant-compute tradeoff that governs base-weight sharding, and you can read it off the same roofline reasoning as everything else.

There is a packing wrinkle worth noting because it interacts with how real checkpoints are laid out. Models fuse q, k, v into one qkv_proj GEMM and gate/up into one gate_up_proj. The adapter weights for those fused projections arrive as separate slices and get packed to match:

# vllm/lora/lora_weights.py
class PackedLoRALayerWeights(LoRALayerWeights):
    """LoRA used for packed layers (eg. qkv_proj)."""

Source: vllm/lora/lora_weights.py

Alongside packing, a small optimize() step folds the per-adapter scaling $\frac{\alpha}{r}$ into $B$ once, at load time, so the hot path multiplies by $1$ instead of rescaling on every step. MoE adds a third axis: each expert is its own GEMM, so a LoRA-adapted MoE layer needs adapter weights stacked over (num_experts, rank, ...) and a fused kernel that routes tokens to experts and to adapter slots simultaneously (punica_wrapper’s add_lora_w13/add_lora_w2, building on the expert-parallel routing of Chapter 15). The complexity compounds, but the principle does not change: one batched launch, per-token indices selecting both expert and adapter.

The prefix cache forks per adapter

Chapter 7 left a hook we now have to honor. Prefix caching rests on one assumption: the KV (the cached keys and values for a span of tokens) is a pure function of the tokens, so two requests with byte-identical prefixes can share the same KV blocks. Adapters break that assumption. The keys and values are produced by the k_proj and v_proj attention projections, and those projections are precisely the weights LoRA modifies. So the same prefix tokens run through adapter 7 produce different keys and values than through adapter 113, even though the token ids are identical. The KV is now a function of the tokens and the adapter. Reusing one adapter’s KV for another would silently feed wrong keys and values into attention, a correctness bug that produces plausible-looking but incorrect output. The fix is the extra_keys mechanism from Chapter 7, forked per adapter:

# vllm/v1/core/kv_cache_utils.py
def _gen_lora_extra_hash_keys(request: Request) -> list[str]:
    if not request.lora_request:
        return []
    return [request.lora_request.lora_name]

Source: vllm/v1/core/kv_cache_utils.py

The adapter name folds into the block hash, so KV blocks computed under adapter 7 are addressable only by requests also using adapter 7. Prefix sharing now happens within an adapter’s request stream, not across adapters. This is the right semantics, and it has a real fleet consequence: a popular system prompt that would have been computed once is now computed once per adapter that uses it, fragmenting cache reuse along adapter lines. There is no free fix; correctness requires the fork. It is, however, the kind of thing a cache-aware router should know about, which is the next handoff.
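A minimal sketch of why the fork works, assuming chained per-block hashing over full blocks, where each block’s hash covers its parent’s hash, its token ids, and any extra keys. It uses hashlib.sha256 for illustration and is not vLLM’s exact hashing scheme; block_hashes is a hypothetical name.

```python
import hashlib


def block_hashes(token_ids, block_size, lora_name=None):
    """Chained hashes of full blocks, with the adapter name as an extra key."""
    extra = () if lora_name is None else (lora_name,)
    hashes, parent = [], b""
    for start in range(0, len(token_ids) - block_size + 1, block_size):
        block = tuple(token_ids[start:start + block_size])
        h = hashlib.sha256()
        h.update(parent)                         # chain: depends on every earlier block
        h.update(repr((block, extra)).encode())  # same tokens + different adapter -> new hash
        parent = h.digest()
        hashes.append(parent.hex())
    return hashes


prompt = list(range(32))
base = block_hashes(prompt, 16)
a7 = block_hashes(prompt, 16, "adapter-7")
print(len(base), set(base).isdisjoint(a7))  # 2 True
```

Identical token ids under two adapters produce disjoint hash sets, so the cache simply never finds a cross-adapter match; no separate correctness check is needed at lookup time.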

Backpressure and the routing signal

The slot budget is a hard scheduling constraint, not a soft preference, and the token-budget scheduler from Chapter 5 enforces it directly. When considering a waiting request, the scheduler checks whether admitting it would need a slot beyond max_loras:

# vllm/v1/core/sched/scheduler.py
if (self.lora_config and request.lora_request
        and (len(scheduled_loras) == self.lora_config.max_loras
             and request.lora_request.lora_int_id not in scheduled_loras)):
    # Scheduling would exceed max_loras, skip.
    request_queue.pop_request()
    step_skipped_waiting.prepend_request(request)
    continue

Source: vllm/v1/core/sched/scheduler.py

A request whose adapter would not fit is deferred, not dropped — it goes back to the waiting queue and is retried next step, the same load-shedding-by-waiting valve from Chapter 6, now keyed on adapter slots rather than KV blocks. This is why the engine exposes adapter occupancy as a first-class metric. The Prometheus gauge reports which adapters are running, which are waiting, and the slot ceiling:
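The admission rule can be isolated in a few lines. This hypothetical pass ignores the token budget and every other constraint the real scheduler checks, keeping only the max_loras test from the excerpt: a request is deferred when its adapter would need a new slot and all slots are already claimed this step.

```python
def schedule_step(waiting, max_loras):
    """Admit (req_id, lora_id) pairs in order; defer, never drop, slot overflows."""
    scheduled, scheduled_loras, deferred = [], set(), []
    for req_id, lora_id in waiting:
        needs_new_slot = lora_id is not None and lora_id not in scheduled_loras
        if needs_new_slot and len(scheduled_loras) == max_loras:
            deferred.append((req_id, lora_id))  # back to the queue, retried next step
            continue
        scheduled.append(req_id)
        if lora_id is not None:
            scheduled_loras.add(lora_id)
    return scheduled, deferred


print(schedule_step([("r1", 7), ("r2", 113), ("r3", None), ("r4", 7), ("r5", 42)], 2))
# (['r1', 'r2', 'r3', 'r4'], [('r5', 42)])
```

Base-model requests and requests for an adapter already admitted this step never consume a new slot, so only r5 waits.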

# vllm/v1/metrics/loggers.py
self.gauge_lora_info = self._gauge_cls(
    name="vllm:lora_requests_info",
    documentation="Running stats on lora requests.",
    ...
    labelnames=[self.labelname_max_lora,
                self.labelname_waiting_lora_adapters,
                self.labelname_running_lora_adapters])

Source: vllm/v1/metrics/loggers.py

That gauge is exactly the kind of engine-internal signal Chapter 18 argued a router must consume. Least-connections routing is wrong here for the same reason it was wrong for prefix locality: the right replica for a request is one that already has that adapter slotted (and won’t have to evict to admit it) and that holds the request’s prefix in the adapter-forked cache. A router that scrapes vllm:lora_requests_info can steer adapter-7 traffic toward replicas already serving adapter 7, packing the adapter working set per replica instead of smearing every adapter across every replica and forcing constant activation churn. The slot budget and the cache fork both push toward adapter affinity in routing, the multi-tenant cousin of the prefix affinity from Chapter 18.
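Adapter affinity can be sketched on top of that gauge. This assumes each replica’s vllm:lora_requests_info sample carries max_lora, running_lora_adapters, and waiting_lora_adapters labels, with adapter names comma-separated; treat that label format as an assumption to verify against your vLLM version. The three-tier policy and the helper names are hypothetical.

```python
def parse_lora_info_labels(labels):
    """Read slot ceiling and adapter sets from one sample's labels (assumed format)."""
    def split(value):
        return {name for name in value.split(",") if name}
    return (int(labels["max_lora"]),
            split(labels.get("running_lora_adapters", "")),
            split(labels.get("waiting_lora_adapters", "")))


def pick_replica(adapter, replica_labels):
    """Prefer replicas already running the adapter, then ones with a free slot."""
    def cost(item):
        _, labels = item
        max_lora, running, _waiting = parse_lora_info_labels(labels)
        if adapter in running:
            tier = 0  # already active: no activation, no eviction
        elif len(running) < max_lora:
            tier = 1  # free slot: activation without eviction
        else:
            tier = 2  # would have to evict, or the request is deferred
        return (tier, len(running))
    return min(replica_labels.items(), key=cost)[0]
```

A production router would blend this tier with the load and prefix-cache scores from Chapter 18 rather than use it alone.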

What is still hard

The headline win is real: hundreds of adapters multiplexed on shared base weights, batched without merging, at close to base-model throughput when the active set fits in slots. The honest caveats are about that last clause and its edges.

Slot thrash is the dominant failure mode. If the live adapter population exceeds max_loras, every step risks evicting an adapter the next step needs, and activation is not free: it copies weights into the stacked GPU buffers and, on a CPU-cache miss, reads them from disk first. A workload with a long tail of rarely-used adapters can spend real time shuffling weights in and out, and the symptom (requests deferred, latency spiking) looks like load when it is actually adapter cache pressure. Raising max_loras costs GPU memory provisioned at full rank whether or not adapters use it; the tradeoff is genuine and workload-specific.

The shape of that failure is the same eviction cliff the prefix cache had in Chapter 7, one level up. The curve below sketches sustained throughput as the live adapter working set grows against a fixed slot budget: while distinct adapters fit in max_loras, every activation is a hit and throughput holds near the base-model rate; once the working set crosses the slot ceiling, each step starts evicting an adapter the next step wants, and throughput falls off as activation cost (GPU weight copies, and disk reloads on a CPU-cache miss) crowds out useful compute.

Illustrative: the flat region (working set within the 16-slot budget) and the post-cliff falloff (thrash as activation cost dominates) have the right shape, but the exact slope of the decline depends on the workload’s adapter reuse pattern and the disk-versus-CPU mix of activations.
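The position of the cliff is easy to reproduce with a toy LRU, assuming the worst case for LRU: requests cycle round-robin through the working set. Under that pattern every access hits (after warm-up) while the working set fits in the slots, and every access misses once it exceeds them, which is the sharpest possible cliff; skewed real traffic degrades more gradually. lru_hit_rate is a hypothetical helper.

```python
from collections import OrderedDict


def lru_hit_rate(working_set, slots, steps):
    """Hit rate of an LRU slot pool when adapters are requested round-robin."""
    cache, hits = OrderedDict(), 0
    for step in range(steps):
        adapter = step % working_set
        if adapter in cache:
            hits += 1
            cache.move_to_end(adapter)     # refresh recency on a hit
        else:
            cache[adapter] = True          # activation: pay the copy
            if len(cache) > slots:
                cache.popitem(last=False)  # evict the least recently used
    return hits / steps


for ws in (8, 16, 17, 32):
    print(ws, lru_hit_rate(ws, 16, 100 * ws))
# 8 0.99
# 16 0.99
# 17 0.0
# 32 0.0
```

The 0.99 reflects only the cold misses of the first pass. One adapter beyond the 16-slot budget turns every activation into a miss.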

Heterogeneous rank is the quiet tax. Buffers are sized at max_lora_rank, so a deployment mixing rank-8 and rank-64 adapters either provisions every slot for 64 (wasting capacity on the small ones) or caps the large ones.

And the cache fork, while correct, is a standing efficiency loss: shared context is recomputed per adapter, with no general way to share the part of the KV the adapter did not change, because in attention the adapter changes the very projections that produce the KV. The S-LoRA and Punica papers below go deeper on slot management and batched-kernel design respectively; both predate the MoE and fully-sharded refinements now in the tree, so read them for the load-bearing ideas, not the current shapes.

With adapters multiplexed and the routing signal exposed, the fleet can serve many models from few replicas. What it cannot yet do is bring a new base model online quickly, or decide how many replicas to run at all. That cold-start-and-scale problem is Chapter 20.

Further reading

  • S-LoRA: Serving Thousands of Concurrent LoRA Adapters — arXiv:2311.03285 — unified paged adapter memory and a fixed slot budget that let one replica hold thousands of adapters; the source of vLLM’s slot-and-eviction model.
  • Punica: Multi-Tenant LoRA Serving — arXiv:2310.18547 — the batched LoRA (SGMV/BGMV) kernel that applies many adapters in one launch via per-token indices; vLLM’s punica_wrapper is named for and cites it.

Deployment, model loading, and autoscaling the fleet

Every chapter so far has assumed the replica is already running. The KV-cache manager from Chapter 4, the token-budget scheduler from Chapter 5, the cache-aware router from Chapter 18 — all of them presuppose a warm process with weights resident in GPU memory, ready to schedule the next step. That assumption is the one this chapter pulls out from under you.

Because the thing your autoscaler does when traffic spikes is create a cold replica, and a cold replica is useless until its weights are on the GPU. A “cold” replica here means a freshly started process that has no model weights in GPU memory yet — it cannot run a single forward pass, because the matrices it needs to multiply are not there. The interval between “scale-up decision” and “first token served” is not network setup or container pull (those you can hide); it is dominated by reading tens or hundreds of gigabytes of weights off storage and landing them in HBM. HBM is High Bandwidth Memory, the GPU’s own on-package RAM where weights must live for the matrix multiplies to read them at full speed. A 70B model in bf16 is $70 \times 10^9 \times 2\ \text{bytes} = 140\ \text{GB}$. Even at a healthy 5 GB/s from local NVMe, that is $140 / 5 = 28$ seconds, nearly half a minute of pure I/O before the scheduler can admit a single request. So the cold-start cost and the autoscaling policy are not two separate concerns. They are the same latency story told from two ends: how fast can a loader fill HBM, and how do you arrange your fleet so you rarely have to pay full price for it.
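The same back-of-envelope bound as a function, assuming weight bytes dominate and the read bandwidth is sustained end to end. It is a lower bound on the I/O portion only, and the helper names are illustrative.

```python
def weight_bytes(params, bytes_per_param=2):
    """Checkpoint size, assuming every parameter is stored at bytes_per_param."""
    return params * bytes_per_param


def load_seconds(params, gb_per_s, bytes_per_param=2):
    """Lower bound on cold-start I/O: weight bytes over sustained read bandwidth."""
    return weight_bytes(params, bytes_per_param) / (gb_per_s * 1e9)


print(load_seconds(70 * 10**9, 5))                     # 28.0 (bf16, 5 GB/s NVMe)
print(load_seconds(70 * 10**9, 5, bytes_per_param=1))  # 14.0 (1 byte per parameter)
```

Under this model the bound scales linearly with checkpoint bytes and inversely with bandwidth, so the levers are fewer bytes or faster delivery of those bytes.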

The diagram below traces the path those bytes take and names the three places they can get stuck — the same three places the rest of this chapter attacks one by one.

flowchart LR
    A["scale-up decision"] --> B["start cold process"]
    B --> C["storage: NVMe / NFS / object store"]
    C -->|"read bytes"| D["CPU RAM (staging)"]
    D -->|"copy H2D"| E["GPU HBM (resident weights)"]
    E --> F["scheduler can admit first request"]
    C -.->|"lever 1: prefetch sequentially"| D
    C -.->|"lever 2: stream straight to GPU"| E
    C -.->|"lever 3: contiguous blob format"| E

This is the reader’s home turf — autoscaling on a queue-depth signal is bread and butter for a traffic engineer. What is unfamiliar is the scale-up latency, which in stateless RPC land is milliseconds and here is seconds-to-minutes, and which is set almost entirely by code in vllm/model_executor/model_loader/.

The loader is a dispatch, not an algorithm

vLLM does not have one way to load weights; it has a family of loaders, and which one runs is selected by a single config field rather than by any runtime logic. A loader, concretely, is a class that knows how to take a checkpoint on storage and populate the model’s tensors with it. The menu lives in vllm/config/load.py:

# vllm/config/load.py
@config
class LoadConfig:
    """Configuration for loading the model weights."""

    load_format: str | LoadFormats = "auto"

The docstring beneath it enumerates the formats (auto, safetensors, runai_streamer, tensorizer, bitsandbytes, sharded_state, and more), and each maps to a BaseModelLoader subclass, the common interface every loader implements. The point worth internalizing is that the default, "auto", is the general-purpose path and usually the slowest reasonable one, while most of the other values trade generality for cold-start latency. The default loader is correct everywhere; the fast loaders are correct somewhere specific. So the engineering question this chapter keeps returning to is: what do you know about your storage and your checkpoint that lets you pick a faster loader without breaking correctness?

The default path is in vllm/model_executor/model_loader/default_loader.py. Its job is unglamorous: figure out which files on disk (or on the Hub, Hugging Face’s model repository) hold the weights, then stream them tensor by tensor into the model. “Tensor by tensor” matters: the model is not one giant array but thousands of named weight matrices (one per layer per projection), and the loader walks them one at a time. The file-resolution logic is a cascade of format guesses. It does not know in advance whether the checkpoint is .safetensors or the older .bin pickle format, so it picks a glob pattern based on load_format:

# vllm/model_executor/model_loader/default_loader.py
if load_format == "hf":
    allow_patterns = ["*.safetensors", "*.bin"]
elif (
    load_format == "safetensors"
    or load_format == "fastsafetensors"
    or load_format == "instanttensor"
):
    use_safetensors = True
    allow_patterns = ["*.safetensors"]
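The cascade can be sketched as "try the patterns in order; the first one that matches any file wins." This is a simplified stand-in written for illustration. resolve_weight_files is hypothetical, and the first-match-wins rule is an assumption about how the pattern list is consumed. The sketch works on a plain list of filenames rather than a Hub listing:

```python
from fnmatch import fnmatch


def resolve_weight_files(filenames: list[str], load_format: str) -> tuple[list[str], bool]:
    """Return (files to load, use_safetensors) for a checkpoint listing.

    Sketch only: patterns are tried in order and the first pattern that
    matches any file wins, so .safetensors is preferred over .bin.
    """
    if load_format == "hf":
        allow_patterns = ["*.safetensors", "*.bin"]
    elif load_format == "safetensors":
        allow_patterns = ["*.safetensors"]
    else:
        raise ValueError(f"format not covered by this sketch: {load_format!r}")

    for pattern in allow_patterns:
        matched = sorted(f for f in filenames if fnmatch(f, pattern))
        if matched:
            return matched, pattern == "*.safetensors"
    return [], False  # nothing loadable under this format
```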

Once it knows the files, load_weights simply hands an iterator to the model and times the result:

# vllm/model_executor/model_loader/default_loader.py
loaded_weights = model.load_weights(self.get_all_weights(model_config, model))

self.counter_after_loading_weights = time.perf_counter()
logger.info_once(
    "Loading weights took %.2f seconds",
    self.counter_after_loading_weights - self.counter_before_loading_weights,
)

Source: vllm/model_executor/model_loader/default_loader.py

That time.perf_counter() delta is the weight-loading share of the cold-start cost, logged on every boot. (time.perf_counter() is Python's high-resolution monotonic timer; subtracting the reading taken before loading from the one taken after gives the elapsed seconds.) It is the number your autoscaler is implicitly fighting, and the number every alternate loader exists to shrink.
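Here is a toy version of that iterator-plus-timer shape, using plain Python lists in place of tensors. weight_iterator, ToyModel, and timed_load are hypothetical names chosen for the sketch:

```python
import time
from typing import Iterator, Optional


def weight_iterator(checkpoint: dict[str, list[float]]) -> Iterator[tuple[str, list[float]]]:
    """Yield (name, tensor) pairs one at a time, like a lazy checkpoint reader."""
    for name in sorted(checkpoint):
        yield name, checkpoint[name]


class ToyModel:
    def __init__(self, param_names: list[str]) -> None:
        self.params: dict[str, Optional[list[float]]] = {n: None for n in param_names}

    def load_weights(self, weights: Iterator[tuple[str, list[float]]]) -> set[str]:
        loaded = set()
        for name, tensor in weights:
            if name in self.params:  # ignore tensors the model has no slot for
                self.params[name] = list(tensor)
                loaded.add(name)
        return loaded


def timed_load(model: ToyModel, checkpoint: dict[str, list[float]]) -> tuple[set[str], float]:
    start = time.perf_counter()
    loaded = model.load_weights(weight_iterator(checkpoint))
    elapsed = time.perf_counter() - start  # seconds, the same kind of delta the loader logs
    return loaded, elapsed
```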

Where the seconds go, and three ways to claw them back

To see what each lever buys, recall the path from the first diagram: bytes travel storage → CPU RAM → GPU HBM. Each lever attacks a different part of that path. The first changes the storage read pattern, the second overlaps the hops so the CPU stage stops being a serial wait, and the third removes the parsing work by changing the on-disk format. The diagram below puts the three side by side against the default path so you can see what each one changes.

flowchart TD
    subgraph Default["default safetensors (slowest, works everywhere)"]
        d1["storage: mmap, lazy random reads"] --> d2["CPU RAM"] --> d3["GPU HBM"]
    end
    subgraph L1["lever 1: prefetch strategy"]
        a1["storage: large sequential reads, page cache warmed"] --> a2["CPU RAM"] --> a3["GPU HBM"]
    end
    subgraph L2["lever 2: Run:ai streamer"]
        b1["object store: S3 / GCS / Azure"] -->|"concurrent streams, overlapped CPU buffer"| b3["GPU HBM"]
    end
    subgraph L3["lever 3: tensorizer format"]
        c1["pre-serialized contiguous blob"] -->|"one read, no parse"| c3["GPU HBM"]
    end

The default safetensors path memory-maps each file and copies tensors out lazily. “Memory-mapping” (mmap) means the file is presented to the program as if it were already an array in memory; the operating system fetches the actual bytes from storage only when a tensor is touched. “Lazily” means a tensor’s bytes are read at the moment the model first asks for that tensor, not up front. On local SSD that is fine: a page fault hits fast local flash. On a network filesystem it is a disaster, because each page fault becomes a round trip over the network, so mmap turns into a storm of small random reads, each paying network latency. vLLM’s first lever is brute-force prefetching: warm the OS page cache (the kernel’s in-RAM buffer of recently read file pages) with large sequential reads before the loader touches the file, so that when mmap later faults, the bytes are already local. The work is sharded across ranks so that the GPUs of a tensor-parallel group (the set of GPUs that split one model’s matrices among themselves, from Chapter 15) do not all hammer the same bytes. Each rank takes a strided slice of the file list:

# vllm/model_executor/model_loader/weight_utils.py
if torch.distributed.is_initialized():
    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
else:
    rank = 0
    world_size = 1
paths_to_prefetch = sorted_files[rank::world_size]

Source: vllm/model_executor/model_loader/weight_utils.py
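The slice starts at the rank and steps by the world size, so files are dealt out round-robin and every file is warmed by exactly one rank. Here is a quick check, using a hypothetical ten-file checkpoint and a hypothetical prefetch_assignment helper that applies the same slice:

```python
def prefetch_assignment(files: list[str], world_size: int) -> dict[int, list[str]]:
    """Which files each rank warms, using the same strided slice as weight_utils.py."""
    sorted_files = sorted(files)
    return {rank: sorted_files[rank::world_size] for rank in range(world_size)}


files = [f"model-{i:05d}-of-00010.safetensors" for i in range(1, 11)]
plan = prefetch_assignment(files, world_size=4)
for rank, paths in plan.items():
    print(rank, [p.split("-")[1] for p in paths])
# 0 ['00001', '00005', '00009']
# 1 ['00002', '00006', '00010']
# 2 ['00003', '00007']
# 3 ['00004', '00008']
```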

The safetensors_load_strategy field documented in load.py (lazy, eager, prefetch) is exactly this lever exposed: on NFS or Lustre (network and parallel-cluster filesystems) you tell the loader to read everything up front instead of paging it in randomly. Note what this does not change — the bytes still travel storage → CPU RAM → GPU, the leftmost path in the diagram above. It only changes the access pattern (sequential instead of random), and for network storage that alone can be the difference between two minutes and twenty seconds, because one large sequential read amortizes the network latency that thousands of small random reads each pay in full.
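A back-of-envelope model shows why the access pattern alone moves the number. Every read pays one network round trip plus its transfer time, so many small reads multiply the latency term while a few large reads barely pay it. The latency, bandwidth, and read sizes below are illustrative assumptions. The model also ignores kernel readahead and concurrent faults, so it exaggerates the gap; the point is which term dominates:

```python
def read_time_s(total_bytes: float, request_bytes: float,
                latency_s: float, bandwidth_bps: float) -> float:
    """Time to read total_bytes issued as serial requests of request_bytes each.

    Each request pays one round-trip latency plus its transfer time.
    Toy model: no request pipelining, no kernel readahead.
    """
    n_requests = total_bytes / request_bytes
    return n_requests * latency_s + total_bytes / bandwidth_bps


GB = 1e9
# Illustrative numbers only: 1 GB of weights, 1 ms network round trip, 1 GB/s link.
random_4k = read_time_s(1 * GB, 4096, 1e-3, 1 * GB)         # page-fault-sized reads
sequential = read_time_s(1 * GB, 64 * 2**20, 1e-3, 1 * GB)  # 64 MiB prefetch reads
```

In this model the transfer term is identical (one second) in both cases. Only the latency term differs, and it swamps everything when reads are page-sized.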

The second lever attacks the CPU hop. In the default path every byte lands in CPU RAM first and is only then copied across the PCIe bus to the GPU, one stage after another. The Run:ai Model Streamer loader (runai_streamer_loader.py) instead streams safetensors directly from object storage (S3, GCS, Azure Blob, the bucket APIs cloud replicas actually pull from) toward GPU memory. It reads many tensors concurrently into a bounded CPU buffer while earlier tensors are already being copied to the GPU, so the CPU hop is still there but overlapped rather than serial. “Concurrency” here is how many reads are in flight simultaneously, and a “memory limit” caps how much the streamer buffers in CPU memory while doing so. Both are tunable through environment variables the loader sets:

# vllm/model_executor/model_loader/runai_streamer_loader.py
if isinstance(concurrency := extra_config.get("concurrency"), int):
    os.environ["RUNAI_STREAMER_CONCURRENCY"] = str(concurrency)
if isinstance(memory_limit := extra_config.get("memory_limit"), int):
    os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str(memory_limit)

This matters for autoscaling specifically because cloud replicas pull from object storage, not a local disk image. A streaming loader with high concurrency turns “download the whole checkpoint, then load it” into one overlapped pipeline, which is the difference that lets you bake a thin container and keep weights in a bucket.
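The gain from overlapping can be sketched as a two-stage pipeline over checkpoint chunks, assuming each chunk has a download time and a load time (both hypothetical inputs). Run sequentially, you pay the sum of both stages. Run as a pipeline, chunk i loads while chunk i+1 downloads:

```python
def sequential_time(download: list[float], load: list[float]) -> float:
    """Download every chunk first, then load every chunk."""
    return sum(download) + sum(load)


def pipelined_time(download: list[float], load: list[float]) -> float:
    """Chunk i loads once it has arrived and chunk i-1 has finished loading."""
    if len(download) != len(load):
        raise ValueError("need one download and one load time per chunk")
    arrived = finished = 0.0
    for dl_s, ld_s in zip(download, load):
        arrived += dl_s                            # downloads run back to back
        finished = max(finished, arrived) + ld_s   # wait for the data and the loader
    return finished
```

With three chunks that each take one second to download and one second to load, the sequential path takes 6 s and the pipelined one 4 s. As the chunk count grows, the pipelined total approaches the slower stage's total alone. Higher concurrency shrinks the download terms themselves, which this toy model leaves out.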

The third lever is to change the format so loading is a single contiguous blob read rather than a parse. The default safetensors path spends real time figuring out where each tensor lives in the file and reshaping it into the model’s layout — that is the “parse” work. CoreWeave’s tensorizer (tensorizer.py) sidesteps it by serializing a vLLM model once into a layout that matches what the device wants, so deserialization is little more than copying a contiguous blob onto the GPU:

# vllm/model_executor/model_loader/tensorizer.py
deserializer.load_into_module(model)
end = time.perf_counter()

total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
duration = end - start
per_second = convert_bytes(deserializer.total_tensor_bytes / duration)

It even reports throughput, because throughput is the whole pitch. But there is an honest caveat baked into the source: tensorizer is only fast when the checkpoint was serialized from vLLM. Hand it a vanilla HuggingFace repo and it warns you it will fall back to a CPU load:

# vllm/model_executor/model_loader/tensorizer.py
logger.warning(
    "Deserializing HuggingFace models is not optimized for "
    "loading on vLLM, as tensorizer is forced to load to CPU. "
    "Consider deserializing a vLLM model instead for faster "
    "load times. ..."
)

Source: vllm/model_executor/model_loader/tensorizer.py
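The idea of "one contiguous blob, no parse" can be illustrated with the standard library alone. This is not tensorizer's file format. serialize and deserialize are hypothetical helpers that store a small JSON header of offsets followed by one contiguous float32 payload, so reading a tensor back is a slice of a single buffer rather than a per-tensor decode. The sketch stops at host memory; landing the bytes on the GPU is out of scope:

```python
import json
import struct
from array import array


def serialize(tensors: dict[str, list[float]]) -> bytes:
    """Pack named float32 tensors as [header length][JSON header][contiguous payload]."""
    header, payload, offset = {}, array("f"), 0
    for name, values in tensors.items():
        header[name] = (offset, len(values))  # element offset and element count
        payload.extend(values)
        offset += len(values)
    header_bytes = json.dumps(header).encode()
    return struct.pack("<I", len(header_bytes)) + header_bytes + payload.tobytes()


def deserialize(blob: bytes) -> dict[str, memoryview]:
    """Read the small header once, then hand out zero-copy views into one payload."""
    (header_len,) = struct.unpack_from("<I", blob, 0)
    header = json.loads(blob[4:4 + header_len])
    payload = memoryview(blob)[4 + header_len:].cast("f")
    return {name: payload[off:off + n] for name, (off, n) in header.items()}
```

All the layout work happens once, in serialize; deserialize only slices. That is the same serialize-once, deserialize-fast bargain described below.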

So the format-changing loaders cost you a pre-processing step (serialize once, deserialize fast forever), a classic cache-build tradeoff. You pay a setup cost so that every subsequent cold start is cheap, which is exactly the bargain you want when an autoscaler will create that replica hundreds of times over its life.
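The bargain can be stated as a break-even count. Let t_ser be the one-time serialization cost, and let t_default and t_fast be the per-boot load times under each loader (symbols introduced here for illustration). Pre-serializing pays for itself once the number of cold starts reaches

$$N^{*} = \left\lceil \frac{t_{\text{ser}}}{t_{\text{default}} - t_{\text{fast}}} \right\rceil, \qquad t_{\text{default}} > t_{\text{fast}}.$$

With made-up numbers, a 600 s serialization step that saves 60 s per boot breaks even after 10 cold starts. An autoscaled replica that boots hundreds of times is far past that point.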

Distributed loading: don’t read what you don’t own

The single biggest structural win, and the place this chapter leans hardest on Chapter 15, is to not load the whole model on every GPU. Recall from Chapter 15 that under tensor parallelism each GPU (each “rank”) holds only a slice of every weight matrix — it never needs the other ranks’ slices to do its own share of the math. Yet the default loader makes every rank read the entire checkpoint and then throw away the parts it does not own, which means an 8-GPU group reads the same 140 GB eight times over. A sharded-state checkpoint fixes this by pre-splitting the weights into per-rank files, so rank k opens only the file holding rank k’s slice. The diagram below contrasts the two.

flowchart TD
    subgraph Naive["default: every rank reads everything"]
        F0["full checkpoint (140 GB)"] --> R0["rank 0 reads all, keeps 1/8"]
        F0 --> R1["rank 1 reads all, keeps 1/8"]
        F0 --> R2["rank ... reads all, keeps 1/8"]
        F0 --> R3["rank 7 reads all, keeps 1/8"]
    end
    subgraph Sharded["sharded_state: each rank reads only its slice"]
        S0["shard 0"] --> T0["rank 0"]
        S1["shard 1"] --> T1["rank 1"]
        S2["shard ..."] --> T2["rank ..."]
        S3["shard 7"] --> T3["rank 7"]
    end

A tensor-parallel rank only needs its shard; a sharded-state checkpoint lets each rank read just its slice:

# vllm/model_executor/model_loader/sharded_state_loader.py
class ShardedStateLoader(BaseModelLoader):
    """
    Model loader that directly loads each worker's model state dict, which
    enables a fast load path for large tensor-parallel models where each worker
    only needs to read its own shard rather than the entire checkpoint.
    """

Source: vllm/model_executor/model_loader/sharded_state_loader.py
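Here is the bytes arithmetic behind the two diagrams, assuming a 140 GB checkpoint split into equal shards across a TP-8 group (bytes_read is a hypothetical helper):

```python
def bytes_read(checkpoint_gb: float, tp_size: int, sharded: bool) -> tuple[float, float]:
    """Return (GB read per rank, GB read across the whole TP group)."""
    per_rank = checkpoint_gb / tp_size if sharded else checkpoint_gb
    return per_rank, per_rank * tp_size


print(bytes_read(140, 8, sharded=False))  # (140, 1120)
print(bytes_read(140, 8, sharded=True))   # (17.5, 140.0)
```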

For a TP-8 deployment (tensor parallelism across 8 GPUs), this cuts per-rank bytes read by roughly 8x, and total bytes pulled from storage by the same factor. Because the eight ranks read their independent files in parallel, wall-clock load time falls by roughly that factor too, as long as the read is bandwidth-bound rather than dominated by fixed per-file overheads. Expert parallelism (EP, also from Chapter 15) extends the same idea to mixture-of-experts (MoE) weights, where the model has many “expert” sub-networks and each rank is responsible for only a subset of them. The default loader, when EP weight filtering is enabled, computes which experts a rank owns and refuses to read the rest from disk:

# vllm/model_executor/model_loader/default_loader.py
self.local_expert_ids = compute_local_expert_ids(
    num_experts,
    ep_size,
    ep_rank,
    placement=parallel_config.expert_placement_strategy,
)

The comment a few lines up is precise about the one case where this is unsafe — EPLB (Expert-Parallel Load Balancing, Chapter 15) dynamically reshuffles which physical GPU slot serves which logical expert to even out load, and it can place a redundant copy of an expert on a rank that the static ownership math says does not own it. If the loader had filtered that expert out, the rank would later be asked to run an expert whose weights it never read. So when EPLB is on, the filter is disabled and every rank loads every expert. Fast loading and dynamic expert rebalancing are in tension, and vLLM resolves it by choosing correctness over the load-time saving. That is the kind of coupling you only see when you read the code.
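The ownership filter and its EPLB escape hatch can be sketched as follows. local_expert_ids, should_read, the two placement names, and the weight-name pattern are all hypothetical illustrations, not vLLM's compute_local_expert_ids or its actual naming scheme:

```python
import re


def local_expert_ids(num_experts: int, ep_size: int, ep_rank: int,
                     placement: str = "linear") -> list[int]:
    """Hypothetical static ownership: which logical experts this EP rank holds."""
    if num_experts % ep_size:
        raise ValueError("sketch assumes experts divide evenly across ranks")
    per_rank = num_experts // ep_size
    if placement == "linear":       # a contiguous block of experts per rank
        return list(range(ep_rank * per_rank, (ep_rank + 1) * per_rank))
    if placement == "round_robin":  # expert e lives on rank e % ep_size
        return list(range(ep_rank, num_experts, ep_size))
    raise ValueError(f"unknown placement {placement!r}")


_EXPERT_RE = re.compile(r"\.experts\.(\d+)\.")


def should_read(weight_name: str, local_ids: set[int], eplb_enabled: bool) -> bool:
    """Skip other ranks' experts, unless EPLB may later place a replica here."""
    match = _EXPERT_RE.search(weight_name)
    if match is None or eplb_enabled:
        return True  # a dense weight, or filtering disabled for correctness
    return int(match.group(1)) in local_ids
```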

One more loader deserves a mention because it blurs the line with Chapter 12: the bitsandbytes path can quantize on the way in. The relevant comment in weight_utils.py, which sits just above the inflight-bitsandbytes branch, makes the mechanism explicit — online quantization ignores the checkpoint’s quant config and instead converts full-precision weights as they are read:

# vllm/model_executor/model_loader/weight_utils.py
# Online quantization doesn't read from checkpoint configs - it quantizes
# fp16/bf16 weights on the fly during loading.

Quantizing during load shrinks the bytes that land in HBM but adds CPU/GPU work to the load itself, so it trades a smaller resident footprint for a slower cold start. Whether that is a win depends entirely on whether you are memory-constrained or scale-up-latency-constrained — and now you can see those are different axes.
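Structurally, quantize-while-loading is a wrapper around the weight iterator. bitsandbytes uses its own formats, so this sketch swaps in the simplest scheme, symmetric per-tensor absmax int8, purely to show where the extra work lands. absmax_int8 and quantize_while_loading are hypothetical names:

```python
from typing import Iterable, Iterator


def absmax_int8(values: list[float]) -> tuple[list[int], float]:
    """Symmetric per-tensor int8: q = round(x / scale), scale = max|x| / 127."""
    amax = max((abs(v) for v in values), default=0.0)
    scale = amax / 127 if amax > 0 else 1.0
    return [max(-127, min(127, round(v / scale))) for v in values], scale


def quantize_while_loading(
    weights: Iterable[tuple[str, list[float]]],
) -> Iterator[tuple[str, list[int], float]]:
    """Wrap a (name, full-precision weights) iterator; only int8 plus a scale is kept."""
    for name, values in weights:
        q, scale = absmax_int8(values)  # extra compute paid during the load
        yield name, q, scale
```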

The process model you are scaling

Before autoscaling makes sense you need to know what a “replica” actually is, because it is not one process. vLLM splits the frontend (the FastAPI/OpenAI HTTP server that accepts requests) from the EngineCore (the GPU-bound process that runs the scheduler and the model worker), and they talk over ZMQ — ZeroMQ, a lightweight message-passing library — the split first seen in Chapter 11. Why split them? So the Python HTTP layer’s overhead never blocks the GPU loop, and so the two halves can have independent lifecycles. The diagram below shows the structure a fleet operator is actually scaling, from the single engine up to a data-parallel group behind one readiness bit.

flowchart TD
    LB["orchestrator / load balancer"] -->|"routes only when ready"| SUP["DP supervisor"]
    SUP -->|"spawn + poll /health"| C0["child 0: frontend + EngineCore (GPU 0)"]
    SUP -->|"spawn + poll /health"| C1["child 1: frontend + EngineCore (GPU 1)"]
    SUP -->|"spawn + poll /health"| C2["child N: frontend + EngineCore (GPU N)"]
    C0 -->|"ZMQ"| E0["scheduler + worker"]
    C1 -->|"ZMQ"| E1["scheduler + worker"]
    C2 -->|"ZMQ"| E2["scheduler + worker"]
    SUP -->|"all_healthy gate"| READY["group ready bit"]

The frontend’s lifecycle is in vllm/entrypoints/launcher.py, and the load-bearing detail for a fleet operator is graceful drain on shutdown:

# vllm/entrypoints/launcher.py
timeout = engine_client.vllm_config.shutdown_timeout
mode = "abort" if timeout == 0 else "drain"

Read that as: a shutdown_timeout of zero means “abort immediately,” and any positive value means “drain,” where in-flight requests keep running until either they all finish or the timeout expires. Scale-down has its own correctness hazard: if you kill a replica that is mid-decode, you drop in-flight requests, and a decode that was forty tokens into a fifty-token answer is simply lost. A drain timeout lets running sequences finish before the process exits. This is the autoscaler-facing half of the lifecycle, the one people forget until they see a latency cliff during every scale-in event.
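The drain-versus-abort rule can be sketched with asyncio, assuming each in-flight request is an asyncio task. shutdown and demo are hypothetical, not vLLM's launcher code:

```python
import asyncio


async def shutdown(in_flight: set[asyncio.Task], timeout: float) -> tuple[int, int]:
    """Drain-or-abort sketch: returns (finished, aborted) request counts.

    timeout == 0 aborts at once; a positive timeout lets in-flight work run
    for up to `timeout` seconds, then cancels whatever is still going.
    """
    if timeout == 0 or not in_flight:
        done = {t for t in in_flight if t.done()}
        pending = set(in_flight) - done
    else:
        done, pending = await asyncio.wait(in_flight, timeout=timeout)
    for task in pending:
        task.cancel()  # these are the dropped, half-finished decodes
    await asyncio.gather(*pending, return_exceptions=True)  # let cancellations settle
    return len(done), len(pending)


async def demo(timeout: float) -> tuple[int, int]:
    async def decode(seconds: float) -> None:
        await asyncio.sleep(seconds)  # stand-in for generating tokens

    tasks = {asyncio.create_task(decode(s)) for s in (0.01, 0.02, 5.0)}
    return await shutdown(tasks, timeout)
```

Here asyncio.run(demo(0.5)) returns (2, 1): the two short decodes drain and the five-second one is cut off. asyncio.run(demo(0)) aborts all three.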

Above a single engine, data-parallel deployments use a supervisor to fan out one process per local GPU. (Data parallelism here means running independent full copies of the model, one per GPU, each serving its own requests — as opposed to the tensor parallelism above, where one model is split across GPUs.) vllm/entrypoints/openai/dp_supervisor.py spawns the children and, crucially, owns a readiness probe — a check that answers “is this process ready to receive traffic yet?”:

# vllm/entrypoints/openai/dp_supervisor.py
for local_rank in range(self.args.data_parallel_size_local):
    child_args = _build_vllm_dp_server_args(self.args, local_rank)
    child_env = _build_vllm_dp_server_env(self.args, local_rank)
    process = context.Process(
        target=_run_vllm_dp_server,
        name=f"APIServer_DPRank_{child_args.data_parallel_rank}",
        args=(child_args, child_env),
    )
    process.start()

The supervisor then polls each child’s /health and only declares the group ready once all children pass:

# vllm/entrypoints/openai/dp_supervisor.py
all_healthy = all(r is True for r in results)

if all_healthy:
    ...
    self._is_ready = True

That all_healthy gate is precisely the signal a Kubernetes readiness probe should wire to. (A readiness probe is the endpoint Kubernetes polls to decide whether to send a pod traffic; until it passes, the pod stays out of the load-balancer rotation.) The crucial design choice is the all in all_healthy: the group reports ready only when every child is, because routing a request to a child that is still loading would either fail or queue behind the cold start. The orchestration platform (K8s, llm-d, and friends) does not need to understand weight loading; it just needs to not route traffic until the slowest child has finished its cold start. The supervisor turns an N-GPU group into a single ready/not-ready bit, and the long tail of that bit, the reason it stays false for tens of seconds, is model load time. This is the precise point where the two halves of the chapter meet: the loader determines how long the bit stays false, and the autoscaler reacts to the bit.
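The supervisor's gate reduces to "poll every child, and report ready only when all of them pass in the same round." Here is a sketch that uses plain callables in place of HTTP requests to each child's /health. wait_until_group_ready and healthy_after are hypothetical:

```python
import time
from typing import Callable, Sequence


def wait_until_group_ready(
    probes: Sequence[Callable[[], bool]],
    timeout_s: float,
    poll_interval_s: float = 0.1,
) -> bool:
    """Poll every child's health probe; ready only when all pass in one round.

    Returns False if the group is still not ready when the deadline passes.
    """
    deadline = time.monotonic() + timeout_s
    while True:
        results = [probe() for probe in probes]
        if all(r is True for r in results):
            return True
        if time.monotonic() >= deadline:
            return False
        time.sleep(poll_interval_s)


def healthy_after(n_failures: int) -> Callable[[], bool]:
    """Hypothetical child that fails its first n_failures probes (still loading)."""
    calls = 0

    def probe() -> bool:
        nonlocal calls
        calls += 1
        return calls > n_failures

    return probe
```

The slowest child sets the answer: one probe that stays False keeps the whole group out of rotation.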

Autoscaling on signals the engine already emits

Here is where your existing instincts mostly transfer, with one substitution. You already know how to scale on a saturation signal. The only question is which signal, and vLLM hands you two from Chapter 2’s metric catalog. The first is queue depth:

# vllm/v1/metrics/loggers.py
gauge_scheduler_waiting = self._gauge_cls(
    name="vllm:num_requests_waiting",
    documentation="Number of requests waiting to be processed.",
    multiprocess_mode="mostrecent",
    labelnames=labelnames,
)

The second is how full the KV cache is:

# vllm/v1/metrics/loggers.py
gauge_kv_cache_usage = self._gauge_cls(
    name="vllm:kv_cache_usage_perc",
    ...
)

Why two? Because, as Chapter 2 insisted, “fast” is meaningless until you name the metric. num_requests_waiting rising means admission is backing up: requests are arriving faster than the scheduler can start them, so you are out of scheduling capacity. kv_cache_usage_perc near 1.0 (the gauge is a fraction, despite its name) means the paged KV cache from Chapter 4 is nearly full, so you are out of memory. From Chapter 6 you know what comes next when it stays there: preemption, the engine shedding load by evicting a running sequence’s cached state and recomputing it later when room frees up. A spike in vllm:num_preemptions is therefore not a warning that pain is coming; it is the system telling you it is already in pain. So a good policy watches KV usage as a leading indicator and queue depth as a confirming one: KV pressure predicts the preemption cliff before you go over it, and queue depth measures the SLO miss once you are already incurring it. Scaling on CPU utilization, the reflex from stateless services, tells you almost nothing here because the work happens on the GPU. Even GPU utilization is ambiguous: a GPU pegged at 100% can be either perfectly healthy (saturated and serving) or thrashing (preempting and recomputing). Utilization cannot distinguish the two; the KV and queue gauges can.
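The policy above, which leads on KV pressure and confirms on queue depth and preemptions, can be written as a small decision function. The thresholds are placeholders, not recommendations. EngineSignals and scaling_decision are hypothetical names, and the fields correspond to the gauges and the preemption rate just discussed:

```python
from dataclasses import dataclass


@dataclass
class EngineSignals:
    kv_cache_usage: float     # vllm:kv_cache_usage_perc, a 0..1 fraction
    requests_waiting: float   # vllm:num_requests_waiting
    preemptions_per_s: float  # rate(vllm:num_preemptions[1m])


def scaling_decision(s: EngineSignals,
                     kv_high: float = 0.90,
                     kv_low: float = 0.50,
                     queue_high: float = 8) -> str:
    """Hypothetical policy: KV pressure leads, queue depth and preemptions confirm.

    Thresholds are placeholders to tune per model and SLO.
    """
    if s.preemptions_per_s > 0 or s.requests_waiting > queue_high:
        return "scale_up"    # already in pain: SLO misses are happening now
    if s.kv_cache_usage >= kv_high:
        return "scale_up"    # leading indicator: the preemption cliff is close
    if s.kv_cache_usage <= kv_low and s.requests_waiting == 0:
        return "scale_down"  # headroom and an empty queue
    return "hold"
```

Note what is absent: there is no CPU or GPU utilization input at all.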

There is a subtlety that ties back to Chapter 18’s routing, and it is the crux of the whole chapter. Scaling is a feedback control loop: a signal is scraped, a controller decides, a replica boots, and capacity arrives, but every stage adds delay. The scrape is periodic, not instantaneous, so the signal lags a scheduler that makes a fresh decision on every engine step, many times per second. By the time num_requests_waiting has climbed high enough to trip a scale-up, and a new replica has paid its cold-start cost, the burst may already be over. The sequence diagram below lays the timeline out so the mismatch is visible: the new capacity arrives after the spike it was meant to absorb.

sequenceDiagram
    participant Traffic
    participant Metrics as metrics scrape
    participant Auto as autoscaler
    participant New as new replica
    Traffic->>Metrics: burst begins, queue climbs
    Note over Metrics: scrape interval delay
    Metrics->>Auto: num_requests_waiting high
    Auto->>New: scale-up, start cold replica
    Note over New: cold start, load weights (tens of s)
    Traffic->>Traffic: burst ends on its own
    New->>Traffic: ready, but too late

This is the fundamental autoscaling-for-LLMs problem the diagram makes concrete: your reaction time (tens of seconds of weight loading) is long compared to the burst duration. You cannot out-react a spike when reacting takes thirty seconds. The curves below put numbers on that mismatch: an incoming burst that arrives and recedes inside about a minute, against the capacity a reactive autoscaler actually delivers — flat through the scrape interval, then through the cold-start load, so the extra replica only comes online after the burst has already drained. The shaded gap between demand and capacity is exactly the unserved load that spills into queue depth and preemptions.

Illustrative: a synthetic burst against a capacity step that lags by a scrape interval plus a ~30 s cold start; shapes and timing are representative, not measured. The second replica (capacity jumps from 300 to 600 req/s near t = 55 s) lands after demand has already fallen back below 150 req/s.
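The same mismatch can be reproduced in a few lines of discrete-time simulation. Every number here is hypothetical, chosen to echo the illustrative curves. Demand is sampled once per second, the autoscaler only sees demand at scrape boundaries, and added capacity arrives a fixed cold start after the decision. unserved_load is a hypothetical helper:

```python
def unserved_load(demand: list[float], base_capacity: float, added_capacity: float,
                  trigger_level: float, scrape_interval_s: int, cold_start_s: int) -> float:
    """Requests above capacity, summed over 1-second steps.

    The autoscaler notices demand > trigger_level only at scrape boundaries,
    and the new replica adds capacity cold_start_s seconds after that.
    """
    ready_at = None
    unserved = 0.0
    for t, d in enumerate(demand):
        if ready_at is None and t % scrape_interval_s == 0 and d > trigger_level:
            ready_at = t + cold_start_s  # the scale-up decision is made here
        added = added_capacity if ready_at is not None and t >= ready_at else 0.0
        unserved += max(0.0, d - (base_capacity + added))
    return unserved


# Hypothetical burst in req/s: quiet for 10 s, a 30 s spike, then quiet again.
burst = [100.0] * 10 + [500.0] * 30 + [100.0] * 30
reactive = unserved_load(burst, base_capacity=300, added_capacity=300,
                         trigger_level=300, scrape_interval_s=15, cold_start_s=30)
warm_pool = unserved_load(burst, base_capacity=600, added_capacity=0,
                          trigger_level=300, scrape_interval_s=15, cold_start_s=30)
```

The reactive replica becomes ready at t = 45 s, five seconds after the burst ended at t = 40 s, so all 6000 excess requests go unserved. A warm pool sized for the peak serves everything.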

The honest answers are all forms of not starting cold — that is, removing the cold-start latency from the critical path rather than trying to win a race against it. Keep a pool of warm-but-idle replicas sized to your burst variance, so a spike is absorbed by capacity that already exists (you pay for idle GPUs to buy latency). Keep weights in the page cache or an object store close to the GPU so the loader’s prefetch path is short. Scale from a snapshot rather than from scratch. And place models so that statistical multiplexing absorbs bursts without any scaling action at all — which is the contribution of AlpaServe (arXiv:2302.11665). Its insight, paraphrased: when you serve many models with bursty, uncorrelated demand, deliberately spreading each model across GPUs via model parallelism lets a burst for one model borrow idle capacity from others, so the right placement raises the load you can serve within an SLO without adding hardware. Read it for the framing that placement and multiplexing are a substitute for fast reaction, not a complement to it — exactly the lever you reach for when, as here, you cannot react fast enough.

What remains unsolved

None of this makes cold start free. The fast loaders shrink the constant; they do not change the fact that scale-up latency is bounded below by weight bytes over storage bandwidth,

$$t_{\text{load}} \geq \frac{\text{weight bytes}}{\text{storage bandwidth}},$$

and that bound grows with model size while autoscaler reaction-time budgets do not. The curves below plot that bound for bf16 checkpoints (2 bytes per parameter) across the three storage regimes this chapter discussed: fast local NVMe, a streamed object store, and an mmap-over-network filesystem doing small random reads. The 70B/140 GB/5 GB/s point from the opening paragraph sits where the NVMe curve crosses 28 seconds, and the same model on a network filesystem is already minutes away from serving its first token.

Illustrative: load times computed from the bytes/bandwidth bound using representative storage bandwidths; real loads add parsing and copy overhead, so these are lower bounds, not measurements.
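The bound itself is one line of arithmetic. Here is a sketch in decimal gigabytes, where 5 GB/s is the figure behind the 70B point cited above and 1 GB/s is an assumed slower store. min_load_seconds is a hypothetical helper:

```python
def min_load_seconds(params_billion: float, bytes_per_param: float,
                     bandwidth_gb_per_s: float) -> float:
    """Lower bound t_load >= weight bytes / storage bandwidth (decimal GB)."""
    weight_gb = params_billion * bytes_per_param
    return weight_gb / bandwidth_gb_per_s


print(min_load_seconds(70, 2, 5))  # 28.0  (70B in bf16 at 5 GB/s)
print(min_load_seconds(70, 2, 1))  # 140.0 (same model at an assumed 1 GB/s)
```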

Scale-from-snapshot (forking a warm process’s GPU state) is promising and still rough at the edges. Right-sizing a warm pool is an open forecasting problem dressed up as a config value. And the laggy-signal problem means the most robust deployments today are the ones that scale least dynamically: generous warm pools and good placement, with reactive autoscaling as a backstop rather than the primary mechanism.

That closes the loop on the fleet. We can load a model fast, stand up a replica, route to it (Chapter 18), serve many adapters on it (Chapter 19), and grow or shrink the set on a signal. The one thing left is to see all of it well enough to know which knob to turn when a number moves the wrong way — which is the subject of the final chapter.

Further reading

  • AlpaServe: Statistical Multiplexing with Model Parallelism for Deep Learning Serving — arXiv:2302.11665 — placement and model-parallel multiplexing absorb bursty multi-model demand, a substitute for fast autoscaler reaction.

Observability and performance tuning: closing the loop

Everything in this book is a mechanism with a knob. The token-budget scheduler from Chapter 5 admits work against max_num_batched_tokens; preemption from Chapter 6 sheds load when the KV pool runs dry; prefix caching from Chapter 7 turns repeated context into a TTFT discount; the encoder budget from Chapter 8 competes for the same step; CUDA graphs from Chapter 10 amortize launch overhead but only for batch shapes that were captured; grammar masks from Chapter 14 add CPU work off the critical path, mostly. Every one of those mechanisms is invisible at the API boundary. A request goes in, tokens come out, and a latency number gets recorded. When that number regresses, you have a dozen suspects and one symptom.

This final chapter is about the bridge between the symptom and the suspect. You already know how to run a service: you watch percentiles, you alert on saturation, you correlate. What is different here is the vocabulary. An LLM engine emits signals that mean nothing unless you know the internals, and it hides the most useful ones inside a separate process. The skill is reading those signals well enough to say not “p99 TTFT is up” but “p99 TTFT is up because the prefix hit rate fell, because a tenant changed their system prompt, and the fix is a larger KV pool, not more replicas.” Localize to a mechanism, then turn the right knob.

That sentence describes a loop, and the loop is the spine of this chapter. You start at a symptom (a percentile moved), narrow it to one mechanism using the engine’s metrics, turn the single knob that controls that mechanism, then watch the same metrics to confirm the mechanism moved the way you predicted and that you did not shove the regression somewhere else. The diagram below traces that cycle, and the rest of the chapter fills in each box: what the metrics are, how to read them, and which knob each one points at.

flowchart LR
    S["symptom: a percentile regressed"] --> L["localize: read engine metrics, find the one mechanism responsible"]
    L --> K["turn the single knob for that mechanism"]
    K --> V["verify: watch the same metrics, did the mechanism move?"]
    V -->|"yes, and no new regression"| DONE["done, for now"]
    V -->|"no, or it broke something else"| L

The series catalog, and why the shape of each metric matters

vLLM’s Prometheus surface is defined in one place, vllm/v1/metrics/loggers.py, in the constructor of PrometheusStatLogger. It is worth knowing the taxonomy before the names, because the taxonomy tells you what each metric can and cannot answer. Prometheus, the metrics system vLLM exports to, defines four metric types (counter, gauge, histogram, and summary), and the three that matter here each answer a different shape of question. A gauge is a single number that can go up or down, sampled at the moment of scrape; it answers “what is true right now.” A counter only ever increases (until its process restarts), so its value is meaningless in isolation, but its slope over a window, computed by the rate(...) function, tells you how fast events are happening. A histogram is a set of cumulative counters, one per pre-defined bucket upper bound, that together let you reconstruct a distribution and ask for a percentile after the fact. Match the question to the kind and the metric reads itself; mismatch them and you will compute nonsense, like a percentile over a gauge or a “current value” of a counter. The diagram below sorts vLLM’s most diagnostic series into these three kinds, with the question each one answers.

flowchart TD
    M["vLLM metric series in loggers.py"] --> G["gauge: value now"]
    M --> C["counter: read as a rate"]
    M --> H["histogram: distribution, query percentiles"]
    G --> G1["num_requests_running / waiting: saturation"]
    G --> G2["kv_cache_usage_perc: how full is the KV pool"]
    C --> C1["num_preemptions: is the engine shedding load"]
    C --> C2["prefix_cache_hits / queries: reuse rate"]
    H --> H1["time_to_first_token_seconds: TTFT SLO"]
    H --> H2["inter_token_latency_seconds: ITL SLO"]
    H --> H3["queue / prefill / decode time: phase breakdown"]

Gauges are instantaneous. They answer “what is happening right now”:

# vllm/v1/metrics/loggers.py
gauge_scheduler_running = self._gauge_cls(
    name="vllm:num_requests_running",
    documentation="Number of requests in model execution batches.",
    multiprocess_mode="mostrecent",
    labelnames=labelnames,
)

Source: vllm/v1/metrics/loggers.py

vllm:num_requests_running and vllm:num_requests_waiting are your saturation signals, the LLM analogue of an in-flight count and a queue depth: how many requests are in the engine’s execution batches this step versus how many are stuck behind them. vllm:kv_cache_usage_perc is the one with no analogue in a stateless service. It is the fraction of the paged KV pool from Chapter 4 that is currently allocated to live requests, and it, not CPU and not request count, is the resource that actually caps how many sequences can run at once. The reason is the autoregressive shape of the workload: every request holds KV blocks for as long as it is generating, and a long generation holds them for a long time, so concurrency is bounded by how much KV memory you have, not by how much compute. When this gauge sits near 1.0, the pool is nearly full and every block is spoken for. New requests simply stay in the waiting queue. A running request whose sequence grows past its last block, however, needs a fresh one, and when none is free the scheduler evicts another running request to make room. That eviction is preemption, and it is the bridge from this gauge to the counter below.
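The concurrency cap falls out of simple block arithmetic. This sketch assumes every sequence reaches the same length and that blocks are never shared (no prefix caching). The pool of 1000 blocks and the block size of 16 tokens are illustrative, and the helper names are hypothetical:

```python
import math


def max_concurrent_sequences(num_kv_blocks: int, block_size: int,
                             tokens_per_seq: int) -> int:
    """Sequences that fit if each holds ceil(tokens / block_size) paged KV blocks."""
    blocks_per_seq = math.ceil(tokens_per_seq / block_size)
    return num_kv_blocks // blocks_per_seq


def kv_usage(allocated_blocks: int, num_kv_blocks: int) -> float:
    """The 0..1 fraction a kv_cache_usage_perc-style gauge reports."""
    return allocated_blocks / num_kv_blocks


print(max_concurrent_sequences(1000, 16, 2048))  # 7
print(max_concurrent_sequences(1000, 16, 512))   # 31
print(kv_usage(896, 1000))                       # 0.896
```

Quartering the sequence length roughly quadruples the concurrency cap on the same hardware, which is why memory, not compute, sets the limit.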

Counters are monotonic; you read them as rates. The preemption counter is the single most diagnostic line in the whole file:

# vllm/v1/metrics/loggers.py
counter_num_preempted_reqs = self._counter_cls(
    name="vllm:num_preemptions",
    documentation="Cumulative number of preemption from the engine.",
    labelnames=labelnames,
)

Source: vllm/v1/metrics/loggers.py

A nonzero rate(vllm:num_preemptions[1m]) means the engine is doing the load-shedding from Chapter 6: freeing a running request’s KV, resetting its computed-token count, and prepending it to the waiting queue to be recomputed later. That recompute is pure waste, and it shows up as ITL spikes for the victims. Preemption is not a bug; it is backpressure working as designed. But a sustained preemption rate means you have admitted more concurrency than the cache can hold, and no amount of routing cleverness upstream will fix it.
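Reading a counter as a rate means dividing its increase by elapsed time while tolerating resets when a process restarts. Here is a simplified sketch of that logic. counter_rate is hypothetical, and unlike Prometheus's rate() it does not extrapolate to the edges of the window:

```python
def counter_rate(samples: list[tuple[float, float]]) -> float:
    """Per-second increase of a monotonic counter over a window of scrapes.

    samples: (unix_seconds, counter_value) pairs in time order. A drop in value
    is treated as a restart (the counter reset to zero).
    """
    if len(samples) < 2:
        return 0.0
    increase = 0.0
    for (_, prev), (_, cur) in zip(samples, samples[1:]):
        increase += (cur - prev) if cur >= prev else cur  # after a reset, count from zero
    elapsed = samples[-1][0] - samples[0][0]
    return increase / elapsed if elapsed > 0 else 0.0
```

For example, scrapes of 100, 103, 109, 2, 5 at 15 s spacing give an increase of 3 + 6 + 2 + 3 = 14 preemptions over 60 s. A naive last-minus-first would have produced a negative number.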

The prefix-cache counters come as a query/hit pair, deliberately counted in tokens, not requests:

# vllm/v1/metrics/loggers.py
counter_prefix_cache_queries = self._counter_cls(
    name="vllm:prefix_cache_queries",
    documentation=(
        "Prefix cache queries, in terms of number of queried tokens."
    ),
    labelnames=labelnames,
)

Source: vllm/v1/metrics/loggers.py

The hit rate is rate(vllm:prefix_cache_hits[1m]) / rate(vllm:prefix_cache_queries[1m]). Token-weighting matters: one request that reuses a 4000-token system prompt is worth far more than a hundred that share nothing, and a request-weighted rate would hide exactly the cache behavior you care about. A drop in this ratio is a common silent TTFT regression in production, because the cause is upstream of the engine entirely. It might be a routing change that stopped sending similar prompts to the same replica, or a tenant who added a timestamp near the top of their system prompt. That timestamp changes the Chapter 7 hash of every block from that point on, since each block’s hash depends on all the tokens before it.
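The gap between the two weightings is easy to show with the traffic mix just described. Each hypothetical request is a (queried tokens, hit tokens) pair. The first reuses a 4000-token system prompt plus a 10-token uncached tail, and the other hundred share nothing:

```python
def token_weighted_hit_rate(reqs: list[tuple[int, int]]) -> float:
    """sum(hit tokens) / sum(queried tokens), the shape of the vLLM counter ratio."""
    queried = sum(q for q, _ in reqs)
    hits = sum(h for _, h in reqs)
    return hits / queried if queried else 0.0


def request_weighted_hit_rate(reqs: list[tuple[int, int]]) -> float:
    """Fraction of requests with any prefix hit, for contrast."""
    return sum(1 for _, h in reqs if h > 0) / len(reqs) if reqs else 0.0


traffic = [(4010, 4000)] + [(10, 0)] * 100
```

The token-weighted ratio comes out near 0.8, while the request-weighted one is about 0.01. The same traffic reads as a healthy cache or a useless one depending on the denominator.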

Histograms are where the SLOs from Chapter 2 actually live. A histogram fixes a set of bucket boundaries up front, and every observation increments the counter for the bucket it falls into; at query time Prometheus interpolates across those buckets to estimate any percentile you ask for. The catch is that resolution lives entirely in the boundaries: a percentile can only be as precise as the buckets near it are dense. That is why the boundaries are hand-tuned rather than evenly spaced. TTFT and inter-token latency each get their own histogram:

# vllm/v1/metrics/loggers.py
histogram_time_to_first_token = self._histogram_cls(
    name="vllm:time_to_first_token_seconds",
    documentation="Histogram of time to first token in seconds.",
    buckets=[
        0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
        0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0, 160.0,
        640.0, 2560.0,
    ],
    labelnames=labelnames,
)

Source: vllm/v1/metrics/loggers.py

Those buckets are dense from a millisecond to a second and sparse after, because that is the range where a TTFT SLO is decided and where a percentile query needs resolution. Past a few seconds the request is already failing its SLO, and you do not need fine resolution to know it; below a millisecond nothing interesting happens. Plotting each boundary against its position in the list, as the curve below does on a log y-axis, makes the hand-tuning literal: the boundaries crawl from a millisecond to one second in the first twelve entries, then grow by factors of roughly 1.3 to 4 per step out to 2560 seconds. Where the line is shallow the buckets are tight and a percentile is sharp; where it shoots up the buckets are coarse and a percentile there is only a rough bound.

The same constructor builds vllm:inter_token_latency_seconds for ITL.
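You can check the resolution argument numerically. The sketch below reuses the TTFT boundaries above and records observations the way Prometheus does, with cumulative le buckets plus an implicit +Inf bucket. It then estimates a quantile by interpolating linearly inside the bucket that holds the target rank. It follows the spirit of PromQL's histogram_quantile() but simplifies its edge cases. observe and estimate_quantile are illustrative names.

```python
import bisect

# Upper bounds (seconds) copied from the TTFT histogram above.
TTFT_BUCKETS = [
    0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
    0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0, 160.0,
    640.0, 2560.0,
]


def observe(cumulative: list[int], bounds: list[float], value: float) -> None:
    """Increment every bucket whose upper bound (le) is >= value."""
    for i, le in enumerate(bounds):
        if value <= le:
            cumulative[i] += 1
    cumulative[-1] += 1  # implicit +Inf bucket counts every observation


def estimate_quantile(q: float, bounds: list[float], cumulative: list[int]) -> float:
    """Interpolate linearly inside the bucket that contains rank q * total."""
    rank = q * cumulative[-1]
    i = bisect.bisect_left(cumulative, rank)  # first bucket with count >= rank
    if i >= len(bounds):
        return bounds[-1]  # rank lands in +Inf: only a lower bound is known
    lower = bounds[i - 1] if i > 0 else 0.0
    below = cumulative[i - 1] if i > 0 else 0
    in_bucket = cumulative[i] - below
    if in_bucket == 0:
        return bounds[i]
    return lower + (bounds[i] - lower) * (rank - below) / in_bucket


counts = [0] * (len(TTFT_BUCKETS) + 1)
for ttft in [0.3] * 90 + [30.0] * 10:  # 90 fast requests, 10 slow ones
    observe(counts, TTFT_BUCKETS, ttft)

p50 = estimate_quantile(0.50, TTFT_BUCKETS, counts)  # true value: 0.3 s
p99 = estimate_quantile(0.99, TTFT_BUCKETS, counts)  # true value: 30 s
print(f"p50 ~ {p50:.3f} s, p99 ~ {p99:.1f} s")
```

The script prints p50 ~ 0.389 s and p99 ~ 38.0 s.

- The p50 estimate is off from the true 0.3 s by less than the 0.25-second width of its bucket.
- The p99 can only say "somewhere between 20 and 40 seconds." That is the coarse-bucket behavior described above.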

The histograms that turn a latency regression into a diagnosis, though, are the per-request phase histograms: vllm:request_queue_time_seconds, vllm:request_prefill_time_seconds, and vllm:request_decode_time_seconds. Each one measures how long a request spent in one stage of its life. The engine-core process stamps a timestamp on a handful of lifecycle events as a request moves through it: queued_ts when the request first lands in the waiting queue, scheduled_ts when the scheduler first admits it into a batch, first_token_ts when the model emits its first output token, and last_token_ts when it emits its final one. The three intervals are just differences between consecutive stamps, computed in vllm/v1/metrics/stats.py:

# vllm/v1/metrics/stats.py
# Queued interval is from first QUEUED event to first SCHEDULED
queued_time = req_stats.scheduled_ts - req_stats.queued_ts
# Prefill interval is from first SCHEDULED to first NEW_TOKEN
prefill_time = req_stats.first_token_ts - req_stats.scheduled_ts
# Decode interval is from first NEW_TOKEN to last NEW_TOKEN
decode_time = req_stats.last_token_ts - req_stats.first_token_ts

The diagram below traces a single request through those four timestamps and shows which interval each pair of stamps defines. Read it top to bottom as the request’s own timeline; the notes between the arrows are the three histograms.

sequenceDiagram
    participant Q as Waiting queue
    participant S as Scheduler
    participant M as Model on GPU
    Q->>S: queued_ts, request arrives
    Note over Q,S: queue_time = scheduled_ts - queued_ts
    S->>M: scheduled_ts, admitted to batch
    Note over S,M: prefill_time = first_token_ts - scheduled_ts
    M->>M: first_token_ts, first output token
    Note over M: decode_time = last_token_ts - first_token_ts
    M->>M: last_token_ts, final output token

This decomposition is the first cut of any latency investigation. End-to-end latency is up, but where did the time go? There are three candidates:

  • Queue time: the request sat waiting because the engine was saturated and could not schedule it.
  • Prefill time: a long prompt, or a cold prefix cache that forced the engine to compute KV it could have reused.
  • Decode time: the batch grew large and per-token latency suffered.

The three histograms answer that question directly. The answer matters because each phase maps one-to-one onto a different fix, which is exactly the knob map at the end of this chapter.

One subtlety is worth knowing: by design, the prefill and decode intervals absorb any preemption that happened during them. Preemption gets no timestamp of its own; it just stretches the wall-clock gap between the surrounding events. A request preempted mid-decode therefore shows inflated decode time rather than landing in a separate “preempted” bucket. That gives you the cross-check you want: a decode-time histogram that suddenly grows a long tail should line up with a nonzero preemption counter.
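Here is a minimal sketch of the same arithmetic. It assumes all four stamps come from one monotonic clock. RequestTimestamps, phase_intervals, and regressed_phase are illustrative names, not vLLM's.

The sketch makes one property explicit: the three intervals telescope, so together they cover exactly queued_ts to last_token_ts. Any latency a client sees beyond that span happened outside EngineCore's stamps. Examples are tokenization before queueing and streaming after the last token.

The second half applies the "first cut" to two windows of hypothetical per-phase p99s rather than to a single request.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class RequestTimestamps:
    """Hypothetical record of the four lifecycle stamps, in seconds."""
    queued_ts: float
    scheduled_ts: float
    first_token_ts: float
    last_token_ts: float


def phase_intervals(ts: RequestTimestamps) -> dict[str, float]:
    # Same differences as the stats.py excerpt above.
    return {
        "queue": ts.scheduled_ts - ts.queued_ts,
        "prefill": ts.first_token_ts - ts.scheduled_ts,
        "decode": ts.last_token_ts - ts.first_token_ts,
    }


def regressed_phase(baseline: dict[str, float], current: dict[str, float]) -> str:
    """Phase whose latency grew most between two windows (e.g. per-phase p99s)."""
    return max(current, key=lambda phase: current[phase] - baseline[phase])


req = RequestTimestamps(queued_ts=100.0, scheduled_ts=100.5,
                        first_token_ts=100.75, last_token_ts=104.75)
intervals = phase_intervals(req)
# The intervals telescope: together they span queued_ts to last_token_ts.
assert math.isclose(sum(intervals.values()), req.last_token_ts - req.queued_ts)

# Hypothetical per-phase p99s (seconds) before and after a regression.
baseline = {"queue": 0.05, "prefill": 0.30, "decode": 4.0}
current = {"queue": 1.20, "prefill": 0.32, "decode": 4.1}
print(regressed_phase(baseline, current))  # queue
```

Here the queue phase grew by more than a second while prefill and decode barely moved. That points at saturation rather than at prompt length or batch size.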

Tokens by source: the prefill that wasn’t

One subtle metric deserves its own paragraph because it closes the loop between Chapters 7, 8, and 16. The scheduler does not just count prompt tokens; it attributes each one to where its KV came from. Recall that before a model can decode, every prompt token needs a key/value entry in the cache, and there are exactly three ways that entry can come to exist: the engine computed it on this GPU just now, it found it already sitting in the local paged cache from a previous request (a prefix-cache hit, Chapter 7), or it pulled it across a KV connector from somewhere else, an offload tier or a remote prefill node (Chapters 16 and 17). Those three sources are mutually exclusive and exhaustive, which is why PromptTokenStats in stats.py states the bookkeeping as an invariant in its own docstring:

# vllm/v1/metrics/stats.py
# Invariants:
#     computed + local_cache_hit + external_kv_transfer = total
#     local_cache_hit + external_kv_transfer = cached_tokens

The first line says every prompt token is accounted for by exactly one of the three sources. The second line says the two non-computed sources together are what we call “cached.” Exposed as vllm:prompt_tokens_by_source, this lets you watch the prefix cache do its job in absolute terms rather than as a ratio. It counts how many prompt tokens were:

  • actually computed (local_compute);
  • served from the local paged cache (local_cache_hit);
  • pulled over a KV connector from the offload tier or a remote prefill (external_kv_transfer).

The hit-rate ratio from the previous section tells you the fraction reused; this metric tells you the absolute count, which is what you need when you are deciding whether a change paid for itself. When you tune gpu_memory_utilization to enlarge the cache, this is the metric that tells you whether the extra blocks bought you reuse (local_cache_hit rises) or just sat idle (it does not). It is also the honest denominator for prefill cost: the GPU-seconds a prompt consumes go into computing its KV, so they are spent on local_compute tokens and almost nothing else. A healthy serving deployment drives that number down over time without dragging accuracy or hit rate down with it.
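The invariants and the before/after comparison can be sketched as follows. The sketch assumes you have already turned the by-source counter into per-window token deltas. PromptTokenWindow and cache_change_report are hypothetical helpers; the field names follow the source labels in the text. The comparison also assumes two equal-length windows with comparable traffic on either side of a gpu_memory_utilization change.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class PromptTokenWindow:
    """Prompt-token deltas over one window (hypothetical helper)."""
    total: int
    local_compute: int
    local_cache_hit: int
    external_kv_transfer: int

    def check(self) -> None:
        # computed + local_cache_hit + external_kv_transfer = total
        attributed = self.local_compute + self.local_cache_hit + self.external_kv_transfer
        if attributed != self.total:
            raise ValueError(f"{attributed} attributed tokens != {self.total} total")

    @property
    def cached_tokens(self) -> int:
        # local_cache_hit + external_kv_transfer = cached_tokens
        return self.local_cache_hit + self.external_kv_transfer

    def fraction(self, tokens: int) -> float:
        return tokens / self.total if self.total else 0.0


def cache_change_report(before: PromptTokenWindow, after: PromptTokenWindow) -> dict:
    """Did a bigger cache buy reuse (local_cache_hit up) or just sit idle?"""
    before.check()
    after.check()
    return {
        "local_hit_share_delta": after.fraction(after.local_cache_hit)
        - before.fraction(before.local_cache_hit),
        "extra_local_hit_tokens": after.local_cache_hit - before.local_cache_hit,
        "compute_share_after": after.fraction(after.local_compute),
    }


before = PromptTokenWindow(total=1_000_000, local_compute=700_000,
                           local_cache_hit=300_000, external_kv_transfer=0)
after = PromptTokenWindow(total=1_000_000, local_compute=550_000,
                          local_cache_hit=450_000, external_kv_transfer=0)
print(cache_change_report(before, after))
```

With these made-up numbers, the enlarged cache turned 150,000 tokens per window from computed into local hits. If extra_local_hit_tokens had stayed near zero, the extra blocks would have been idle.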

CUDA-graph fallback, hidden in plain sight

Chapter 10’s whole argument was that decode steps can replay as captured CUDA graphs, a recorded sequence of GPU kernel launches that the driver fires in one shot instead of the CPU dispatching each kernel individually, which erases per-step launch overhead. The catch is that a graph is captured for a specific batch size: a graph recorded for 32 sequences only knows how to run 32. So at every step the engine faces a batch of whatever size the scheduler produced and has to pick one of three outcomes. If the size exactly matches a captured graph, it replays it, the fast path. If it is smaller than a captured size, it can pad up, run the next-larger captured graph and waste compute on the padding rows. If it is larger than anything captured, it falls back to eager execution, dispatching kernels one at a time and paying exactly the launch overhead the graph was meant to erase. The two slow outcomes are real costs and easy to miss because no latency histogram names them, so vLLM records a per-step stat whose fields tell you exactly which case you hit:

# vllm/compilation/cuda_graph.py
@dataclasses.dataclass(frozen=True)
class CUDAGraphStat:
    num_unpadded_tokens: int
    num_padded_tokens: int
    num_paddings: int
    runtime_mode: str

Source: vllm/compilation/cuda_graph.py

Those four fields are exactly the three-way decision made visible. num_unpadded_tokens is the batch the scheduler actually produced; num_padded_tokens is what it was rounded up to in order to hit a captured size; num_paddings is the difference, the wasted rows; and runtime_mode records which path ran. The diagram below is that per-step decision, and the stat is just a tally of which branch each step took.

flowchart TD
    B["decode step: batch of N sequences"] --> Q{"N matches a captured graph size?"}
    Q -->|"yes"| R["replay graph (fast path, num_paddings = 0)"]
    Q -->|"no, N smaller than a captured size"| P["pad up to next size, replay (num_paddings > 0)"]
    Q -->|"no, N larger than any captured size"| E["eager fallback (runtime_mode = eager, pay launch overhead)"]

Aggregated by CUDAGraphLogging into a frequency table over many steps, this is how you catch a regression that no latency histogram explains cleanly. If runtime_mode is frequently the eager fallback, your live batch sizes are landing outside the captured set, and you are paying the launch overhead the graph was supposed to erase. If num_paddings is large, you captured coarse batch sizes and are wasting compute padding small batches up to the next one. Both point at cudagraph_mode and the capture-size list, not at the scheduler. This metric is gated behind observability_config.cudagraph_metrics precisely because computing it every step is itself overhead; you turn it on when you suspect graph trouble and leave it off otherwise.
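The three-way decision and the frequency table fit in a few lines. The sketch assumes a sorted list of captured sizes measured in tokens, as the stat's field names suggest. dispatch and tally are hypothetical helpers. The mode strings are illustrative labels, not vLLM's exact runtime_mode values.

```python
import bisect
from collections import Counter


def dispatch(num_tokens: int, capture_sizes: list[int]) -> tuple[str, int, int]:
    """Return (mode, num_padded_tokens, num_paddings) for one step."""
    i = bisect.bisect_left(capture_sizes, num_tokens)  # first size >= num_tokens
    if i == len(capture_sizes):
        return "eager", num_tokens, 0  # larger than anything captured
    padded = capture_sizes[i]  # exact match, or next-larger captured size
    return "graph", padded, padded - num_tokens


def tally(step_sizes: list[int], capture_sizes: list[int]) -> tuple[Counter, int]:
    """Count exact replays, padded replays, and eager fallbacks, plus wasted rows."""
    outcomes: Counter = Counter()
    wasted = 0
    for n in step_sizes:
        mode, _, paddings = dispatch(n, capture_sizes)
        if mode == "eager":
            outcomes["eager"] += 1
        elif paddings:
            outcomes["padded"] += 1
        else:
            outcomes["exact"] += 1
        wasted += paddings
    return outcomes, wasted


captured = [1, 2, 4, 8, 16, 32, 64]
outcomes, wasted = tally([8, 9, 33, 64, 80], captured)
print(outcomes, wasted)
```

Two steps in this run show the slow outcomes. The 33-token step pads up to the 64-token graph, nearly doubling that step's work. The 80-token step falls off the end into eager mode. Those are the two symptoms that point at the capture-size list rather than the scheduler.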

MFU: are you even near the roofline?

The deepest metric in the tree is the one that connects back to Chapter 3’s roofline. MFU, model FLOPs utilization, is the fraction of your accelerator’s peak arithmetic throughput that the model is actually using, achieved FLOP/s over peak FLOP/s; the bandwidth-utilization figure is its memory-side twin, the fraction of peak HBM bandwidth in use. Neither number is measured directly from hardware counters. Instead vllm/v1/metrics/perf.py carries an analytic model of the transformer: given the exact batch the scheduler ran, it computes from first principles how many floating-point operations that forward pass requires and how many bytes it must read and write per GPU, broken down by component (attn, ffn, unembed). Those estimates are exported as counters, and you turn them into utilization with a rate query. The header on the Prometheus class spells out the intended form:

# vllm/v1/metrics/perf.py
#   rate(vllm:estimated_flops_per_gpu_total[1m]) / 1e12
#
# Average memory bandwidth in GB/s can be calculated using:
#   (rate(vllm:estimated_read_bytes_per_gpu_total[1m]) +
#    rate(vllm:estimated_write_bytes_per_gpu_total[1m])) / 1e9

The first expression yields average TFLOP/s per GPU (hence the division by 1e12), and the second yields GB/s. Divide the estimated FLOP/s by your accelerator’s peak FLOP/s and you get MFU. Divide the byte rate by peak HBM bandwidth and you get bandwidth utilization, achieved bytes/s over peak bytes/s.

The crucial point is that a low MFU is not automatically bad, and reading it correctly requires the prefill-versus-decode asymmetry from Chapter 3. Decode is memory-bound: each step reads the whole model and the KV cache to produce a single token per sequence, so there is very little arithmetic per byte moved. That is a low arithmetic intensity $I = \text{FLOPs} / \text{byte}$, and a healthy decode-heavy workload will sit near the bandwidth roofline while showing single-digit MFU. That low MFU is the expected, correct reading, not a problem to fix, because there are simply no FLOPs there to extract.

The identical single-digit MFU in a prefill-heavy workload means the opposite. Prefill is compute-bound, so low MFU there says you are leaving arithmetic on the table. The cause is almost always batches too small to keep the matrix units busy, and the fix is to admit more work per step. The roofline draws the line between the two regimes at the balance-point intensity $I_{\text{balance}}$, the FLOPs-per-byte at which the compute and bandwidth ceilings meet:

$$I_{\text{balance}} = \frac{\text{peak FLOP/s}}{\text{peak bytes/s}}$$

A workload with $I < I_{\text{balance}}$ is bandwidth-bound (decode), and one with $I > I_{\text{balance}}$ is compute-bound (prefill). Same number, opposite diagnosis, and the only way to tell them apart is to know which phase dominates. The roofline below draws both regimes on one log-log plot: the sloped left arm is the bandwidth ceiling (achievable FLOP/s rises with intensity because more arithmetic rides on each byte moved), the flat right arm is the compute ceiling (peak FLOP/s, you cannot go faster no matter the intensity), and the knee where they meet is $I_{\text{balance}}$. Decode lives far down the sloped arm at low intensity, hard against the bandwidth ceiling; prefill lives out on the flat arm where the only way up is bigger batches. A single-digit MFU on the left arm is the ceiling; the identical number on the right arm is wasted compute.

Illustrative: ceilings use round H100-class numbers (peak ~1.0 PFLOP/s, ~3.35 TB/s HBM, so $I_{\text{balance}}\approx 299$ FLOP/byte); the shape and the two-regime split are the point, not the exact coordinates.

The analytic model behind these counters is approximate. For one, it assumes perfect MoE expert load balance, which Chapter 15 showed is a fiction under skew, so it overstates utilization exactly when routing is lopsided. It is also gated behind enable_mfu_metrics because computing it has a cost. But it is the closest thing the engine gives you to an answer to “are we using the hardware we are paying for.”
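The path from query results to MFU, and from intensity to regime, fits in a short sketch. It assumes the round H100-class ceilings from the caption; substitute your accelerator's datasheet numbers. For intensity, it assumes a dense model with 16-bit weights that does about 2 FLOPs per parameter per token and reads each parameter's 2 bytes once per step. KV-cache and activation traffic are ignored, so real decode intensity is lower still. All function names are illustrative.

```python
# Round H100-class ceilings from the caption; replace with your accelerator's numbers.
PEAK_FLOPS = 1.0e15         # FLOP/s
PEAK_BYTES_PER_S = 3.35e12  # bytes/s of HBM bandwidth
I_BALANCE = PEAK_FLOPS / PEAK_BYTES_PER_S  # about 299 FLOP/byte


def mfu(tflops_per_gpu: float) -> float:
    """tflops_per_gpu: result of rate(vllm:estimated_flops_per_gpu_total[1m]) / 1e12."""
    return tflops_per_gpu * 1e12 / PEAK_FLOPS


def bandwidth_utilization(gb_per_s: float) -> float:
    """gb_per_s: result of the (read + write) byte-rate query, already divided by 1e9."""
    return gb_per_s * 1e9 / PEAK_BYTES_PER_S


def attainable_flops(intensity: float) -> float:
    """Roofline ceiling: the lower of the compute and bandwidth limits."""
    return min(PEAK_FLOPS, intensity * PEAK_BYTES_PER_S)


def regime(intensity: float) -> str:
    return "bandwidth-bound" if intensity < I_BALANCE else "compute-bound"


def weight_intensity(tokens_per_step: int, bytes_per_param: float = 2.0) -> float:
    """Dense-model estimate: ~2 FLOPs per parameter per token, every parameter
    read once per step; KV-cache and activation traffic are ignored."""
    return 2.0 * tokens_per_step / bytes_per_param


# Decode: tokens_per_step is the batch size. Prefill: the whole chunk counts.
for tokens in (8, 32, 2048):
    i = weight_intensity(tokens)
    print(tokens, regime(i), f"ceiling MFU {attainable_flops(i) / PEAK_FLOPS:.1%}")
```

Under these assumptions, a decode batch of 8 can reach at most about 2.7% MFU and a batch of 32 about 10.7%, both on the sloped arm. A 2048-token prefill chunk clears the knee and is limited only by peak compute. To diagnose a deployment, compare its measured mfu() value against the ceiling for its dominant phase.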

Scrape it, and mind which process you are profiling

The collection side is mundane, and that is the point. The example in examples/observability/prometheus_grafana/ is a Prometheus job pointed at the server’s /metrics:

# examples/observability/prometheus_grafana/prometheus.yaml
global:
  scrape_interval: 5s
scrape_configs:
  - job_name: vllm
    static_configs:
      - targets:
          - 'host.docker.internal:8000'

Source: examples/observability/prometheus_grafana/prometheus.yaml
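The scrape cadence above is what sets up the first pitfall below, and the smearing it causes is easy to reproduce. This sketch scrapes a hypothetical preemption counter every 5 seconds, matching the scrape_interval in that config. A burst of 50 preemptions lands in one interval. The sketch then computes a per-second rate over a 1-minute and a 10-second window.

window_rate is a simplification of PromQL's rate(). The real function also extrapolates toward the window edges and handles counter resets. Treat it as an illustration of the averaging, not a reimplementation.

```python
from typing import Optional


def window_rate(samples: list[tuple[float, float]], t_end: float,
                window: float) -> Optional[float]:
    """Per-second increase between the first and last samples in (t_end - window, t_end]."""
    inside = [(t, v) for t, v in samples if t_end - window < t <= t_end]
    if len(inside) < 2:
        return None  # a rate needs at least two samples
    (t0, v0), (t1, v1) = inside[0], inside[-1]
    return (v1 - v0) / (t1 - t0)


# Counter scraped every 5 s for two minutes; 50 preemptions land between t=60 and t=65.
samples = [(t, 0.0 if t <= 60 else 50.0) for t in range(0, 121, 5)]

print(window_rate(samples, t_end=65, window=60))  # ~0.91/s: spike flattened
print(window_rate(samples, t_end=65, window=10))  # 10.0/s: spike visible
```

The same five-second burst of 10 preemptions per second reads as under one per second through a 1-minute window. A window that holds only a single scrape cannot produce a rate at all.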

Two non-obvious things bite people here. First, scrape interval interacts with your bucket math. A 5-second scrape and a rate(...[1m]) window will smear short preemption bursts together, averaging a sharp spike down into a low plateau. For incident forensics you want either a tighter rate window or the raw histogram. The window must still hold at least two scrapes, since rate() needs two samples.

Second, and more important, recall the frontend/EngineCore process split from Chapter 11. vLLM does not run in one process. The HTTP server, request tokenization, and SSE response streaming live in the API-server (frontend) process. The scheduler, the model forward pass, and the GPU live in a separate EngineCore process, and the two talk over ZMQ (ZeroMQ), a fast messaging library for inter-process sockets. That split is what makes the phase histograms so valuable for localization, because the boundary between “frontend cost” and “engine cost” runs straight through it. The diagram below shows the two processes, what each holds, and how a metric scrape and a profiler attach see them differently.

flowchart LR
    CLIENT["client"] -->|"HTTP"| FE
    subgraph FE["API-server process (frontend)"]
        HTTP["HTTP server"]
        TOK["tokenization"]
        SSE["SSE streaming / detokenization"]
    end
    FE <-->|"ZMQ"| EC
    subgraph EC["EngineCore process"]
        SCHED["scheduler"]
        MODEL["model forward pass"]
        GPU["GPU / KV cache"]
    end
    PROM["Prometheus /metrics"] -.->|"multiprocess mode: aggregates both"| FE
    PROM -.-> EC
    PROF["torch profiler"] -.->|"attaches to ONE pid only"| EC

The /metrics endpoint aggregates both processes via Prometheus multiprocess mode, so the numbers you scrape already span the whole engine. A profiler, though, attaches to a single process and sees only that one, which is why localization has to come first. Suppose TTFT is slow but the phase histograms say queue time and prefill time are both fine. Then the lost time is in the frontend (tokenization, detokenization, or the egress path), and you profile the API-server process. If the phase histograms say prefill or decode is slow, the cost is in the engine, and you profile EngineCore.

Profiling the engine therefore means attaching to the right PID. vLLM exposes /start_profile and /stop_profile endpoints that the API server forwards over ZMQ to the EngineCore, which calls down into the worker:

# vllm/v1/engine/core.py
def profile(self, is_start: bool = True, profile_prefix: str | None = None):
    self.model_executor.profile(is_start, profile_prefix)

Source: vllm/v1/engine/core.py

That call starts and stops a PyTorch profiler inside the worker process. You enable it by setting --profiler-config.torch_profiler_dir. Both the frontend’s CPU trace and the worker’s combined CPU-and-GPU trace land in that one directory for side-by-side viewing. Where a metric is a single aggregated number, a trace is a timeline of every kernel and every gap between kernels on the GPU. That resolution is what lets you see whether a slow decode step is bound by:

  • the attention kernel itself;
  • the launch gaps that CUDA graphs were supposed to close;
  • the grammar bitmask copy from Chapter 14 quietly landing on the critical path.

The division of labor is clean: metrics localize a regression to a process and a mechanism, and the trace then localizes it to a specific kernel.
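Driving those endpoints from a script lets you bracket exactly the traffic you want in one trace. This sketch uses only the standard library and makes several assumptions:

  • the server listens on localhost:8000;
  • /start_profile and /stop_profile accept an empty POST request (check your version's API server before relying on this);
  • profile_request and profile_window are hypothetical helpers;
  • the opener and sleep parameters exist only so the sketch can be exercised without a server.

```python
import time
import urllib.request


def profile_request(base_url: str, action: str) -> urllib.request.Request:
    """Build an empty POST to /start_profile or /stop_profile."""
    if action not in ("start", "stop"):
        raise ValueError(f"unknown action: {action}")
    return urllib.request.Request(f"{base_url}/{action}_profile", data=b"", method="POST")


def profile_window(base_url: str = "http://localhost:8000", seconds: float = 10.0,
                   opener=urllib.request.urlopen, sleep=time.sleep) -> None:
    """Profile roughly `seconds` of live traffic, always sending the stop call."""
    with opener(profile_request(base_url, "start")) as resp:
        resp.read()
    try:
        sleep(seconds)  # drive, or wait for, the traffic you want captured
    finally:
        with opener(profile_request(base_url, "stop")) as resp:
            resp.read()  # may take a while if the profiler writes traces on stop
```

Against a running server with profiling enabled, profile_window(seconds=10) brackets about ten seconds of traffic. The try/finally guarantees the profiler is stopped even if the wait is interrupted, so a forgotten trace does not keep growing.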

The knob-to-mechanism map

Once you have localized, tuning is almost mechanical, because each symptom has a primary knob and most of the knobs trade along the throughput-versus-latency axis from Chapter 2. The diagram below is that mapping as a decision tree: start at the symptom, read the two or three metrics that distinguish its causes, and arrive at the one knob to turn. The prose after it walks each branch and names the tradeoff you are accepting.

flowchart TD
    START["a percentile regressed; read the metrics"] --> W{"num_requests_waiting high?"}
    W -->|"yes, but kv usage moderate and preemptions zero"| K1["scheduler-throughput-limited: raise max-num-batched-tokens"]
    W -->|"yes, kv usage near 1.0 with steady preemptions"| K2["cache-limited: lower max-num-seqs or raise gpu-memory-utilization"]
    START --> T{"TTFT p99 dragged up by rare long prompts, hurting everyone's ITL?"}
    T -->|"yes"| K3["chunked-prefill interference: clamp long-prefill-token-threshold"]
    START --> MM{"multimodal: good throughput but erratic TTFT?"}
    MM -->|"yes"| K4["encoder budget contending: tune max-num-batched-tokens"]
    START --> CG{"CUDA-graph stats show eager fallback or heavy padding?"}
    CG -->|"yes"| K5["fix cudagraph_mode and capture sizes (not the scheduler)"]

If vllm:num_requests_waiting is high but kv_cache_usage_perc is moderate and preemptions are zero, you are scheduler-throughput-limited, not memory-limited. Raise --max-num-batched-tokens to let more work into each step, accepting some ITL cost. If kv_cache_usage_perc rides near 1.0 with a steady preemption rate, you are cache-limited. Lower --max-num-seqs to admit fewer concurrent sequences, or raise --gpu-memory-utilization to grow the pool if you have headroom.

If TTFT p99 is dragged up by occasional very long prompts while ITL for everyone else suffers, that is the chunked-prefill interference from Chapter 6. Clamp it with --long-prefill-token-threshold so a giant prefill is sliced thinner and shares each step more politely.

If a multimodal deployment shows good token throughput but erratic TTFT, the encoder budget from Chapter 8 is contending with the token budget. In this version of vLLM that budget is not yet an independent flag: it is pinned to max_num_batched_tokens (max_num_encoder_input_tokens is marked not-currently-configurable in vllm/config/scheduler.py). So --max-num-batched-tokens is the lever you actually have, and turning it trades encoder headroom against decode ITL.

And if the CUDA-graph stats show frequent eager fallback or heavy padding, the fix is cudagraph_mode and the capture sizes, not anything in the scheduler.
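For runbooks, the same branches can be encoded as a small function that returns one knob to turn. That fits the change-one-knob loop described next. This sketch rests on loud assumptions:

  • Snapshot is a hypothetical bundle of already-scraped signals.
  • The two thresholds are placeholders you would calibrate per deployment, not recommendations.
  • Checking the branches in the order the prose lists them is an arbitrary choice.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Snapshot:
    """Hypothetical bundle of signals already read from /metrics."""
    num_requests_waiting: float
    kv_cache_usage_perc: float        # 0.0 to 1.0
    preemptions_per_s: float
    ttft_p99_from_long_prompts: bool  # rare very long prompts drive the tail
    multimodal_erratic_ttft: bool     # good throughput, erratic TTFT
    cudagraph_eager_or_padding: bool  # from the CUDA-graph frequency table


def first_knob(s: Snapshot, waiting_high: float = 8, kv_full: float = 0.95) -> Optional[str]:
    """Walk the branches in the order the text lists them; return one knob or None."""
    if s.num_requests_waiting >= waiting_high:
        if s.kv_cache_usage_perc >= kv_full and s.preemptions_per_s > 0:
            return "lower --max-num-seqs or raise --gpu-memory-utilization"
        if s.kv_cache_usage_perc < kv_full and s.preemptions_per_s == 0:
            return "raise --max-num-batched-tokens"
    if s.ttft_p99_from_long_prompts:
        return "clamp --long-prefill-token-threshold"
    if s.multimodal_erratic_ttft:
        return "tune --max-num-batched-tokens (encoder budget)"
    if s.cudagraph_eager_or_padding:
        return "fix cudagraph_mode and capture sizes"
    return None


busy_but_roomy = Snapshot(40, 0.55, 0.0, False, False, False)
print(first_knob(busy_but_roomy))  # raise --max-num-batched-tokens
```

Returning a single knob, rather than a list, is deliberate: it keeps each iteration of the loop attributable to one change.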

The honesty clause: none of these knobs is independent. Raising max_num_batched_tokens helps prefill throughput and hurts decode ITL. Raising gpu_memory_utilization grows the cache but shrinks the activation headroom, and too far will OOM under a burst. Lowering max_num_seqs calms preemption but caps throughput. There is no setting that is good for every workload, which is why the loop is a loop: change one knob, watch the same metrics, confirm the mechanism moved the way you predicted, and check that you did not push the regression somewhere else.

What is still open

The biggest unsolved problem is that these signals are emitted per replica, and a fleet is many replicas behind a router. The router from Chapter 18 wants prefix-locality and queue-depth signals to make good decisions, and the block-storage events from Chapter 7 give it an approximate cache view. But turning per-replica metrics into a fleet-level control loop remains more craft than science. Chapter 20 began that work, autoscaling on num_requests_waiting and kv_cache_usage_perc together rather than on CPU. Attribution across the connector boundary is genuinely hard: when external_kv_transfer tokens dominate, part of the latency was spent in someone else’s prefill pool. And the analytic MFU model degrades exactly where serving gets interesting: under MoE skew, under speculative decoding’s variable acceptance, and under disaggregated phases that no single replica’s counters can see whole.

That is the real shape of the work. The engine will tell you almost everything, in a vocabulary this book has spent twenty chapters teaching you to read. The loop you close with it is never finished, only quieter.

Further reading

  • Taming Throughput-Latency Tradeoff in LLM Inference with Sarathi-Serve — arXiv:2403.02310 — the goodput framing and chunked-prefill tradeoff that every knob in this chapter is tuning against.
  • Orca: A Distributed Serving System for Transformer-Based Generative Models — OSDI ’22 — iteration-level batching, the source of the per-step scheduling signals you are reading.
  • DistServe: Disaggregating Prefill and Decoding for Goodput-optimized Large Language Model Serving — arXiv:2401.09670 — why prefill and decode want different rooflines, and why one replica’s MFU cannot see the whole picture once they are split.

References

The papers cited across this book, aggregated from each chapter’s “Further reading” section and deduplicated. They are listed in order of arXiv identifier; the one venue-only citation is listed last.

  • The Curious Case of Neural Text Degeneration — arXiv:1904.09751
  • Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism — arXiv:1909.08053
  • Data Movement Is All You Need: A Case Study on Optimizing Transformers — arXiv:2007.00072
  • Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity — arXiv:2101.03961
  • GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers — arXiv:2210.17323
  • SmoothQuant: Accurate and Efficient Post-Training Quantization for Large Language Models — arXiv:2211.10438
  • Fast Inference from Transformers via Speculative Decoding — arXiv:2211.17192
  • AlpaServe: Statistical Multiplexing with Model Parallelism for Deep Learning Serving — arXiv:2302.11665
  • LLaVA: Visual Instruction Tuning — arXiv:2304.08485
  • AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration — arXiv:2306.00978
  • FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning — arXiv:2307.08691
  • Efficient Memory Management for Large Language Model Serving with PagedAttention — arXiv:2309.06180
  • CacheGen: KV Cache Compression and Streaming for Fast Large Language Model Serving — arXiv:2310.07240
  • Punica: Multi-Tenant LoRA Serving — arXiv:2310.18547
  • S-LoRA: Serving Thousands of Concurrent LoRA Adapters — arXiv:2311.03285
  • Splitwise: Efficient Generative LLM Inference Using Phase Splitting — arXiv:2311.18677
  • SGLang: Efficient Execution of Structured LM Programs — arXiv:2312.07104
  • DistServe: Disaggregating Prefill and Decoding for Goodput-optimized Large Language Model Serving — arXiv:2401.09670
  • Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads — arXiv:2401.10774
  • EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty — arXiv:2401.15077
  • Taming Throughput-Latency Tradeoff in LLM Inference with Sarathi-Serve — arXiv:2403.02310
  • DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model — arXiv:2405.04434
  • Preble: Efficient Distributed Prompt Scheduling for LLM Serving — arXiv:2407.00023
  • Mooncake: A KVCache-centric Disaggregated Architecture for LLM Serving — arXiv:2407.00079
  • Qwen2-VL: Enhancing Vision-Language Model’s Perception of the World at Any Resolution — arXiv:2409.12191
  • SuffixDecoding: Extreme Speculative Decoding for Emerging AI Applications — arXiv:2411.04975
  • XGrammar: Flexible and Efficient Structured Generation Engine for Large Language Models — arXiv:2411.15100
  • DeepSeek-V3 Technical Report — arXiv:2412.19437
  • Orca: A Distributed Serving System for Transformer-Based Generative Models — OSDI ’22