The hardware reality: memory-bound decode, compute-bound prefill, arithmetic intensity

Chapter 1 framed a request as a long-lived two-phase state machine, and Chapter 2 gave us the vocabulary to argue about it: TTFT for the first phase, ITL/TPOT for the second, throughput and goodput for the fleet, GPU-seconds for the bill. Those metrics described what we measure. This chapter explains why the two phases behave so differently that they need separate metrics in the first place. The answer is not in the model architecture or the scheduler. It is in the silicon, and specifically in one number that every serving engineer eventually learns to feel in their bones: the ratio of arithmetic operations to bytes moved.

Here is the claim, and it is worth stating bluntly because nearly every technique in the rest of this book is a corollary of it. Prefill is compute-bound. Decode is memory-bandwidth-bound. They are bottlenecked on different physical resources, and a serving stack that ignores this will leave one of those resources idle while the other saturates. If you internalize only one idea from Part I, make it this one.

The roofline, in one breath

You already reason about systems in terms of bottleneck resources: a proxy is CPU-bound here, network-bound there. GPUs have the same dichotomy, formalized as the roofline model. Two definitions first, because everything in this chapter hangs on them. A kernel is one unit of work the GPU runs — a matrix multiply, an attention computation, an element-wise add. Its arithmetic intensity is a property of that work: the number of floating-point operations it performs divided by the number of bytes it must read from memory to perform them. It is a ratio, $\text{FLOPs} / \text{byte}$, and it does not depend on how fast the hardware is. It only describes how much math the kernel squeezes out of each byte it fetches.

Now bring in the hardware. HBM (high-bandwidth memory) is the GPU’s main memory — the multi-gigabyte pool where weights and activations live. It is fast by RAM standards and glacial by compute standards, and that gap is the whole story. The chip has two hard ceilings: peak compute, how many FLOP/s the tensor cores can grind through, and peak memory bandwidth, how many bytes/s it can pull from HBM. Divide the first by the second and you get the machine balance point $I_{\text{balance}}$, the arithmetic intensity at which the two ceilings are reached simultaneously:

$$I_{\text{balance}} = \frac{\text{peak FLOP/s}}{\text{peak bytes/s}}$$

It is the break-even ratio: the number of FLOPs the hardware can do in the time it takes to fetch one byte.

The model is called a “roofline” because of the shape you get when you plot achievable performance against arithmetic intensity. To the left of the balance point the ceiling is a slanted line set by bandwidth; to the right it flattens into a horizontal line set by compute. A kernel whose arithmetic intensity sits to the left of the balance point finishes its handful of FLOPs long before the bytes it needs have arrived, so the tensor cores stall waiting on HBM — it is memory-bound, and the only way to go faster is to move fewer bytes. A kernel to the right has so much math per byte that the data arrives faster than the cores can consume it — it is compute-bound, and the only way to go faster is more FLOP/s. The single question that decides a kernel’s fate is which side of the balance point its arithmetic intensity lands on.

The curve below is that plot, on log-log axes for a representative datacenter GPU (~3 TB/s of HBM bandwidth, ~1000 TFLOP/s of peak compute, so a balance point near 333 FLOP/byte). The sloped left edge is the bandwidth ceiling, the flat right edge is the compute ceiling, and their corner is the machine balance. The two markers are the operating points this chapter is about: decode lands at the far left (arithmetic intensity ~1, pinned to the bandwidth ceiling with the tensor cores almost idle), while prefill, whose intensity grows with prompt length, sails past the balance point and rides the flat compute roof.

Illustrative: the ceiling and operating points use representative GPU specs to show the shape and the relative positions of the two phases, not measured kernel data.
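The roofline ceiling can be sketched in a few lines. This is a minimal illustration using the representative figures quoted above (~1000 TFLOP/s, ~3 TB/s), not any specific product's datasheet; the function and constant names are made up for this example.

```python
def attainable_flops(intensity, peak_flops, peak_bandwidth):
    """Roofline ceiling: achievable FLOP/s at a given arithmetic intensity (FLOP/byte)."""
    # Left of the balance point the bandwidth term is smaller (sloped roof);
    # right of it the compute term is smaller (flat roof).
    return min(peak_flops, intensity * peak_bandwidth)


# Representative figures from the text, not measured hardware data.
PEAK_FLOPS = 1000e12       # ~1000 TFLOP/s
PEAK_BANDWIDTH = 3e12      # ~3 TB/s

balance = PEAK_FLOPS / PEAK_BANDWIDTH   # FLOP/byte where the two ceilings meet

for intensity in (1, 10, 100, balance, 1000):
    ceiling = attainable_flops(intensity, PEAK_FLOPS, PEAK_BANDWIDTH)
    regime = "memory-bound" if intensity < balance else "compute-bound"
    print(f"{intensity:8.1f} FLOP/byte -> {ceiling / 1e12:7.1f} TFLOP/s ({regime})")
```

At an intensity of 1 FLOP/byte the ceiling is only about 3 TFLOP/s out of a nominal 1000, which is the gap the rest of the chapter is about.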

The roofline above is the lens we apply to both phases. A modern datacenter GPU has a machine balance somewhere in the low hundreds of FLOP per byte. Hold that number; both phases get measured against it, and the rest of this chapter is the story of how prefill lands on the right of it and decode crashes to the far left.

Why prefill saturates compute

Prefill processes the whole prompt at once. A prompt of $S$ tokens flows through the model as an $S$-row activation matrix — each row is one token’s $d$-dimensional vector, where $d$ (the hidden size, a few thousand for a large model) is the width of the model’s internal representation. The heavy operations are the linear projections: a weight matrix of shape $(d \times d)$ multiplied against the $(S \times d)$ stack of token vectors. Count the work. The matmul performs on the order of $S \cdot d^2$ multiply-adds — one $d^2$ block of work for each of the $S$ rows. The bytes it must read are roughly $d^2$ weights (the weight matrix, loaded once) plus $S \cdot d$ activations (the input rows). For a large model and a prompt of any real length, the $S \cdot d^2$ arithmetic dwarfs the $d^2$ weight read.

Here is the key move, and it is what pulls prefill to the compute-bound side. You load the $(d \times d)$ weight matrix from HBM exactly once, then reuse it across all $S$ rows of the prompt. Every additional token in the prompt is more FLOPs spent against weights you have already paid to fetch. So arithmetic intensity scales with $S$, at roughly one FLOP per byte per prompt token in 16-bit precision: the longer the prompt, the more math per byte, and once the prompt reaches a few hundred tokens it sails clean past the machine balance point. The GPU’s tensor cores run hot. This is the phase where a GPU does what marketing says it does: teraflops of dense linear algebra. This is also why TTFT scales with prompt length the way it does (more rows, more compute, and prefill is on the compute ceiling), and it is the physical reason Chapter 17 will argue for putting prefill on its own hardware pool.
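To make the scaling concrete, here is a rough sketch of the arithmetic intensity of one $(S \times d) @ (d \times d)$ projection. It follows the byte accounting in the text (weights plus input activations read once), counts a multiply-add as 2 FLOPs, and assumes 16-bit elements and a hidden size of 8192. The function name and the chosen sizes are illustrative assumptions, and real kernels also write outputs and may re-read tiles.

```python
def projection_intensity(num_rows, hidden_size, bytes_per_element=2):
    """Approximate FLOP/byte for a (num_rows x d) @ (d x d) projection."""
    # One multiply-add per (row, weight) pair, counted as 2 FLOPs.
    flops = 2 * num_rows * hidden_size * hidden_size
    # Bytes read: the weight matrix once, plus the input activation rows.
    weight_bytes = hidden_size * hidden_size * bytes_per_element
    activation_bytes = num_rows * hidden_size * bytes_per_element
    return flops / (weight_bytes + activation_bytes)


HIDDEN_SIZE = 8192     # illustrative hidden size
BALANCE = 1000e12 / 3e12   # ~333 FLOP/byte, the representative GPU from the text

for rows in (1, 16, 64, 256, 512, 2048):
    intensity = projection_intensity(rows, HIDDEN_SIZE)
    side = "compute-bound" if intensity >= BALANCE else "memory-bound"
    print(f"S = {rows:5d}: {intensity:7.1f} FLOP/byte ({side})")
```

With these assumptions the intensity works out to $S \cdot d / (d + S)$, which is close to $S$ whenever the prompt is much shorter than the hidden size.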

Why decode starves on bandwidth

Now the autoregressive loop. After prefill, the model emits one token, appends it to the sequence, and runs another full forward pass to get the next — but this time the input is a single token, while the weights are still the entire model. That mismatch is the whole problem. The matmuls degrade from matrix-matrix to matrix-vector: instead of $(S \times d) @ (d \times d)$ you now compute $(1 \times d) @ (d \times d)$, a single row against the weight matrix. You read the same $d^2$ weights from HBM as before, but now you do only $d^2$ multiply-adds against that one row before moving on. There is no second, third, or $S$-th row to amortize the fetch over. Each weight is loaded from HBM, used exactly once, and discarded. Arithmetic intensity collapses to roughly 1 FLOP per byte — far below the machine balance — and so the tensor cores sit idle while the memory bus runs flat out. This is the same $(S \times d) @ (d \times d)$ projection as prefill; the only thing that changed is $S$ fell from “the whole prompt” to one.

The diagram below traces the two phases side by side, so the asymmetry is concrete: same weights, same kernel, opposite bottleneck.

flowchart LR
    subgraph P["prefill: S input rows"]
        direction TB
        P1["read (d x d) weights once"] --> P2["multiply against all S rows"]
        P2 --> P3["S x d^2 FLOPs per d^2 bytes read"]
        P3 --> P4["arithmetic intensity grows with S, compute-bound"]
    end
    subgraph D["decode: 1 input row"]
        direction TB
        D1["read (d x d) weights once"] --> D2["multiply against the single row"]
        D2 --> D3["d^2 FLOPs per d^2 bytes read"]
        D3 --> D4["arithmetic intensity ~1, memory-bound"]
    end

The brutal consequence, visible in the decode branch of the diagram: a single decode step’s latency is, to first order, the time to stream the model’s weights through the memory bus once. Walk the floor explicitly. A 70B-parameter model in 16-bit precision (2 bytes per parameter) is ~140 GB of weights. Every decode step must read all of them at least once, because every layer’s projection needs its slice of those weights. On a GPU with ~3 TB/s of bandwidth, reading 140 GB once takes ~47 ms,

$$\frac{140 \text{ GB}}{3000 \text{ GB/s}} \approx 47 \text{ ms}$$

and that is a floor: even if the tensor cores were infinitely fast, the bus cannot deliver the weights faster, so no amount of spare compute can lower it. This is the single most important sentence for an infra engineer to absorb, because it reframes the entire problem. You are not compute-starved during the phase that dominates wall-clock time for long generations. You are bandwidth-starved, and the levers that matter are the ones that move fewer bytes.
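The same floor as a small calculation. This is a first-order sketch that, like the text, treats the whole model as streamed once per step at the quoted bandwidth; it ignores KV-cache reads, kernel launch overhead, and any sharding across devices. The function name is hypothetical.

```python
def decode_step_floor_ms(num_params, bytes_per_param, bandwidth_bytes_per_s):
    """Lower bound on one decode step: time to stream every weight once."""
    weight_bytes = num_params * bytes_per_param
    return weight_bytes / bandwidth_bytes_per_s * 1e3


BANDWIDTH = 3e12   # ~3 TB/s, the representative figure from the text

floor_ms = decode_step_floor_ms(70e9, 2, BANDWIDTH)
print(f"16-bit weights: {floor_ms:.1f} ms per step, "
      f"at most {1e3 / floor_ms:.1f} tokens/s for a single sequence")

# Fewer bytes per parameter lowers the floor proportionally; more FLOP/s would not.
for bytes_per_param in (2, 1, 0.5):
    ms = decode_step_floor_ms(70e9, bytes_per_param, BANDWIDTH)
    print(f"{bytes_per_param} bytes/param -> {ms:.1f} ms")
```

Note that peak FLOP/s does not appear anywhere in the function: under this model, the only inputs that move the floor are the byte count and the bandwidth.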

And it is not only weights that the bus must carry. The KV cache from Chapter 1 — the stored keys and values for every prior token — must also be read on every decode step, and here is why it has to be. Attention works by having the new token compare itself against every earlier token (those are the keys) and then pull a weighted blend of what each earlier token carried (those are the values). So to generate token number $n$, the step must read the keys and values of all $n - 1$ tokens before it. That read grows with the length of the conversation, and it is paid again on every single step. Long contexts and large batches make the KV cache rival or exceed the weights in bytes moved per step — a second bandwidth bill stacked on top of the weight read, and one that, unlike the fixed weight cost, keeps growing as the sequence gets longer. That is exactly why Chapter 4 treats KV memory, not compute, as the binding constraint on concurrency, and why Chapter 12’s KV-cache quantization and Chapter 9’s MLA both attack the bytes of the cache rather than its FLOPs.
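A back-of-the-envelope sketch of that second bandwidth bill. The layer count, KV-head count, and head dimension below describe a hypothetical 70B-class configuration with grouped-query attention, chosen only for illustration; the function names are likewise assumptions, and the estimate assumes a 16-bit, uncompressed cache.

```python
def kv_bytes_per_token(num_layers, num_kv_heads, head_dim, bytes_per_element=2):
    """Bytes of KV cache stored (and re-read each step) per token of context."""
    # Factor of 2: one key vector and one value vector per layer per KV head.
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_element


def kv_read_bytes_per_step(batch_size, context_len, per_token_bytes):
    """Bytes of KV cache a decode step reads across the whole batch."""
    return batch_size * context_len * per_token_bytes


# Hypothetical configuration, not taken from any model card.
PER_TOKEN = kv_bytes_per_token(num_layers=80, num_kv_heads=8, head_dim=128)
WEIGHT_BYTES = 140e9   # the 70B, 16-bit example from the text

for batch, context in ((1, 2048), (32, 8192), (64, 32768)):
    kv = kv_read_bytes_per_step(batch, context, PER_TOKEN)
    print(f"batch={batch:3d} context={context:6d}: "
          f"KV read {kv / 1e9:7.1f} GB/step, {kv / WEIGHT_BYTES:5.2f}x the weights")
```

Under these assumptions a lone short-context decode reads well under a gigabyte of cache, while a large long-context batch reads several times the weight footprint on every step.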

This asymmetry has a name in the literature worth reading: “Data Movement Is All You Need: A Case Study on Optimizing Transformers” (arXiv:2007.00072). The paper’s central finding is that transformer workloads are dominated not by arithmetic but by data movement — the operations that look cheap on a FLOP count (normalizations, element-wise ops, the plumbing between matmuls) are disproportionately expensive once you account for memory traffic, and operator fusion that keeps data in fast memory is where the wins live. The case study is on training a BERT encoder, but the methodology — classify every operator by its arithmetic intensity, then fuse the memory-bound ones — is exactly the roofline lens this chapter applies to inference; read it for how to measure where the bytes actually go.

Where the step actually happens in vLLM

Enough physics. The asymmetry is not an abstraction in vLLM — it is visible in the shape of the tensors that flow through one step. The forward pass itself is a single call in the model runner (vllm/v1/worker/gpu_model_runner.py):

        return self.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
            **model_kwargs,
        )

Nothing here distinguishes prefill from decode. That is the point: the same code path runs both. What differs is the number of rows in input_ids. The scheduler hands the runner a flat batch of scheduled tokens, some of which belong to decodes (one token for a running sequence) and some to prefill chunks (many tokens for a sequence catching up). Its own design note (vllm/v1/core/sched/scheduler.py) puts it bluntly: “there’s no ‘decoding phase’ nor ‘prefill phase’ in the scheduler.” A batch heavy with prefill chunks has many rows and high arithmetic intensity; a batch of pure decodes has one row per sequence and craters into the memory-bound regime. The hardware reality and the scheduler’s uniformity are two views of the same thing, and Chapter 5 will make that uniform token-budgeted batch the central abstraction of the engine.

The asymmetry shows up again in what happens after the forward pass. The model produces a hidden state for every input token, but a forward pass that processed $S$ prompt tokens still only needs to predict one next token, so only the final position’s hidden state is useful for sampling. vLLM gathers exactly those positions before computing logits (vllm/v1/worker/gpu_model_runner.py):

            sample_hidden_states = hidden_states[logits_indices]
            logits = self.model.compute_logits(sample_hidden_states)

And logits_indices is computed as the last token of each sequence in the batch (vllm/v1/worker/gpu_model_runner.py):

            logits_indices = query_start_loc[1:] - 1

Unpack that one-liner, because it is a tidy piece of indexing. query_start_loc is the prefix-sum of per-request token counts: if the batch holds requests of 5, 1, and 3 rows, it reads [0, 5, 6, 9], marking where each request’s rows begin in the flat batch. Drop the leading 0 ([1:] gives [5, 6, 9]) and subtract one ([4, 5, 8]) and you have the index of the last row of every request — exactly the position whose hidden state predicts that request’s next token. Everything else the prefill computed — the hidden states for all the interior prompt tokens — is thrown away for sampling purposes, because no token needs to be predicted from the middle of a prompt you already have. The comment a few lines up is candid about a related subtlety: a chunked-prefill request that has not finished its prompt produces a logit too, but the runner notes it will “ignore the sampled tokens from the partial requests.” This is the code-level confirmation of a fact this chapter has been circling: prefill produces no output token of its own. Its entire job is to populate the KV cache so that decode can begin. Chapter 11 will pick up this thread — the sampler operates only on [num_reqs, vocab] logits, one row per request — and explain why the only mandatory GPU-to-CPU sync per step is shipping back those sampled ids.
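The indexing can be replayed outside vLLM with plain Python lists. This is a stand-in sketch, not the runner's code: the real implementation operates on tensors, and the helper name and string placeholders for hidden states are invented here.

```python
from itertools import accumulate


def last_token_indices(tokens_per_request):
    """Index of each request's final row in a flat, concatenated batch."""
    # Prefix sum with a leading 0 marks where each request's rows begin.
    query_start_loc = [0, *accumulate(tokens_per_request)]
    # Equivalent of `query_start_loc[1:] - 1` on a tensor.
    return [start - 1 for start in query_start_loc[1:]]


tokens_per_request = [5, 1, 3]          # the example from the text
logits_indices = last_token_indices(tokens_per_request)

# Placeholder "hidden states": one label per row of the flat batch.
hidden_states = [f"h{i}" for i in range(sum(tokens_per_request))]
sample_hidden_states = [hidden_states[i] for i in logits_indices]

print(logits_indices)          # [4, 5, 8]
print(sample_hidden_states)    # ['h4', 'h5', 'h8']
```

Nine rows go into the forward pass and three come out for sampling, one per request, regardless of how many prompt tokens each request contributed.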

Speaking of which: the asymmetry even governs the one place the GPU must talk to the CPU. After sampling, the token ids cross back over the PCIe bus (vllm/v1/worker/gpu_model_runner.py):

                valid_sampled_token_ids = self._to_list(sampled_token_ids)

For a decode-dominated batch this transfer is tiny — one integer per sequence — yet it imposes a synchronization point that, left naive, would stall the bandwidth-bound GPU on a slow CPU round-trip. Chapter 10’s async scheduling exists precisely because, when each step is already as cheap as one weight-streaming pass, even a CPU hiccup becomes a visible fraction of ITL. You can already see the seam in the code: an alternate path caches the sampled ids on the GPU as prev_sampled_token_ids “to avoid CPU sync.” Hold that for Chapter 10.

A note on backends, and why this book stays on GPUs

vLLM is deliberately multi-backend. The same attention interface is implemented for CUDA (FlashAttention, FlashInfer), AMD’s ROCm (vllm/v1/attention/backends/rocm_attn.py), CPUs (vllm/v1/attention/backends/cpu_attn.py), Intel XPU, and TPU, selected at startup behind the dispatch seam Chapter 9 will dissect. The roofline argument is not GPU-specific — every accelerator has a compute ceiling and a bandwidth ceiling and therefore a balance point — but the numbers that make decode catastrophically memory-bound (hundreds of GB of weights versus single-token activations) are sharpest on high-bandwidth GPU memory, and the production fleet this book targets is GPU-based. So from here on the prose says “GPU” and “HBM” and means it, while you should remember the principle ports to any backend with the obvious substitutions.

The tradeoff this creates, and the one it does not solve

The asymmetry hands you a lever and a trap. The lever: decode wastes compute, so you can do more arithmetic per byte moved almost for free. Here is the mechanism that makes it free. A lone decode reads the whole weight matrix to multiply against one row. Put $B$ sequences’ decode steps in the same batch and the matmul becomes $(B \times d) @ (d \times d)$ — you still read each weight from HBM exactly once, but now you reuse it across all $B$ rows. The bytes moved barely change; the FLOPs multiply by $B$. Arithmetic intensity climbs back toward the balance point, and you have converted idle compute into served tokens at almost no extra bandwidth cost. The curve below shows that climb: decode arithmetic intensity rises roughly linearly with batch size, so a lone decode sits deep in the memory-bound regime while a batch of a few hundred reaches the balance point and finally puts the tensor cores to work. This is the throughput half of the throughput/latency tradeoff from Chapter 2, and the physical reason continuous batching (Chapter 5) is the foundational optimization. Speculative decoding (Chapter 13) is the same lever from another angle: verify several draft tokens in one weight-streaming pass, spending idle compute to amortize the memory read.

Illustrative: decode arithmetic intensity scales as roughly one FLOP/byte per batched sequence; the crossing point depends on the GPU’s balance point (here ~333) and is shown to convey the shape, not a measured kernel.
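A simplified model of why batching is nearly free until the balance point. It assumes the step time is the larger of the weight-streaming time and the compute time, uses the common approximation of 2 FLOPs per parameter per token, and deliberately ignores KV-cache reads, so it is a sketch of the weight-reuse effect only. The names and the 70B example sizes are illustrative.

```python
def decode_step_time_s(batch_size, num_params, bytes_per_param,
                       peak_flops, peak_bandwidth):
    """Roofline estimate of one batched decode step (weights only, no KV reads)."""
    # Weights are streamed once per step no matter how many rows reuse them.
    memory_time = num_params * bytes_per_param / peak_bandwidth
    # Approximation: ~2 FLOPs per parameter for each token in the batch.
    compute_time = 2 * num_params * batch_size / peak_flops
    return max(memory_time, compute_time)


PARAMS, BYTES_PER_PARAM = 70e9, 2
PEAK_FLOPS, PEAK_BANDWIDTH = 1000e12, 3e12   # representative figures from the text

for batch in (1, 8, 64, 256, 512, 1024):
    step = decode_step_time_s(batch, PARAMS, BYTES_PER_PARAM, PEAK_FLOPS, PEAK_BANDWIDTH)
    print(f"B = {batch:4d}: step {step * 1e3:6.1f} ms, {batch / step:8.0f} tokens/s")
```

In this model the step time stays flat while throughput climbs linearly with batch size, and only past the balance point does each added sequence start to lengthen the step.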

The trap: batching decode helps only until you run out of the other scarce resource. Every sequence you add to the batch carries its own KV cache, and KV cache is GPU memory — capacity, the third ceiling, distinct from both compute and bandwidth. So as you crank up batch size to chase throughput, the binding constraint silently migrates. The state diagram below traces the journey: a lone decode is bandwidth-bound; batch it and you climb back toward compute-bound, the regime you wanted; keep batching and the accumulated KV cache eventually fills HBM, and now you are capacity-bound and cannot admit another sequence at all. Which ceiling bites first depends on model size, context length, and load, and the path is not one-way — a batch of long-context decodes can be capacity-bound before it ever feels compute-bound. Naming that migration precisely is the job of the next chapter.

stateDiagram-v2
    [*] --> MemoryBound: single decode step
    MemoryBound --> ComputeBound: batch more sequences (reuse weights)
    ComputeBound --> CapacityBound: KV cache fills HBM
    CapacityBound --> MemoryBound: preempt or evict a sequence (Ch 6)
    MemoryBound: bandwidth-bound, bus saturated, cores idle
    ComputeBound: compute-bound, tensor cores busy, good throughput
    CapacityBound: capacity-bound, no HBM left for more KV cache
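The migration between ceilings can be sketched as a small classifier. Everything here is a hypothetical simplification: a single GPU with 80 GB of HBM, an 8B-parameter model in 16-bit precision, an invented KV configuration, and a balance batch size of ~333 taken from the representative GPU above. It also ignores activation memory and the bandwidth cost of KV reads.

```python
def binding_constraint(batch_size, context_len, *, hbm_bytes, weight_bytes,
                       kv_bytes_per_token, balance_batch):
    """Name the ceiling that binds first under a deliberately crude model."""
    kv_bytes = batch_size * context_len * kv_bytes_per_token
    if weight_bytes + kv_bytes > hbm_bytes:
        return "capacity-bound"      # the batch's KV cache does not fit in HBM
    if batch_size >= balance_batch:
        return "compute-bound"       # enough rows reuse each weight
    return "bandwidth-bound"         # weights streamed for too few rows


# Hypothetical setup for illustration only.
SETUP = dict(
    hbm_bytes=80e9,
    weight_bytes=8e9 * 2,                    # 8B parameters, 2 bytes each
    kv_bytes_per_token=2 * 32 * 8 * 128 * 2, # K and V, 32 layers, 8 KV heads, dim 128
    balance_batch=333,
)

for batch, context in ((1, 1024), (400, 1024), (64, 8192)):
    print(f"batch={batch:3d} context={context:5d}: "
          f"{binding_constraint(batch, context, **SETUP)}")
```

The third case is the long-context path described above: 64 sequences at 8192 tokens exhaust HBM well before the batch is large enough to become compute-bound.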

What the hardware does not hand you is a way to make a single decode step faster than one pass over the bytes it must read. You cannot out-schedule the speed of light on the memory bus. Every technique that genuinely lowers per-token latency, rather than raising throughput, works by reducing the bytes: quantize the weights so there are fewer to stream (Chapter 12), compress the KV cache so attention reads less (Chapters 9 and 12), or skip the read entirely for shared prefixes that are already resident (Chapter 7). Keep that taxonomy in hand. When a later chapter claims a speedup, the first question is always: did it add FLOPs to a bandwidth-bound step (cheap), or did it cut the bytes (the only thing that lowers the latency floor)?

We now know the shape of the problem and the physics underneath it. The next chapter follows the bytes to their largest and most dynamic source — the KV cache — and shows why treating it as paged memory, not one buffer per sequence, is what lets a single replica hold enough sequences in flight to keep that hungry memory bus busy.

Further reading

  • Data Movement Is All You Need: A Case Study on Optimizing Transformers — arXiv:2007.00072 — transformer workloads are dominated by data movement, not arithmetic; a BERT-training case study whose roofline/fusion methodology is the foundation for the memory-bound decode argument.