
Chunked prefill, token budgets, and preemption as load-shedding

The previous chapter left us with a scheduler that re-plans the batch every step against a single number: a token budget. There is no prefill phase and no decode phase, only requests whose num_computed_tokens is chasing their num_tokens_with_spec, and a budget that says how many tokens the whole batch may advance this step. That reframing was the hard conceptual move. This chapter cashes it in.

Two problems remain unaddressed, and both are exactly the kind a traffic engineer is primed to see. The first is interference. A token budget tells you how much total work a step may contain, but it says nothing about how that work is distributed across requests. If one request shows up with a 30,000-token prompt and you let its prefill consume an entire step, every other request in the batch gets zero decode tokens that step. Their inter-token latency spikes. You have a head-of-line blocking problem dressed up as a prefill. The second problem is admission under pressure: the budget bounds compute per step, but the KV cache is a separate, finite resource, and nothing yet stops the scheduler from admitting more sequences than their growing caches will fit. When that happens, something has to give.

Chunked prefill answers the first. Preemption-with-recompute answers the second. They turn out to be the same machinery viewed from two angles, and both fall directly out of the token-budget loop you already understand.

Before diving in, a quick recap of the vocabulary this chapter leans on, all carried over from Chapter 5. A request’s num_computed_tokens is how many of its tokens the engine has already pushed through the model and written into the KV cache; num_tokens is how many it has in total (prompt plus tokens generated so far). A request is prefilling while num_computed_tokens is still climbing toward the end of its prompt, and decoding once it has caught up and is producing one new token per step. The token budget (max_num_batched_tokens) is the cap on how many tokens the whole batch may advance in a single schedule() step. The KV cache is a fixed pool of fixed-size blocks, each holding the keys and values for a handful of token positions; every token the engine computes must be written into a block, so the pool is a hard ceiling on how many tokens can be live at once. Keep those four ideas in hand and the rest of the chapter is bookkeeping on top of them.
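To pin those terms down, here is a minimal sketch of per-request bookkeeping. ToyRequest, its num_prompt_tokens field, and the block size of 16 are illustrative assumptions, not vLLM’s actual Request class. The sketch only shows how the four ideas relate. A request is prefilling while its computed count trails the prompt length. Once it has caught up it sits one token behind num_tokens. Every computed token occupies a slot in some fixed-size block.

```python
import math
from dataclasses import dataclass

BLOCK_SIZE = 16  # hypothetical tokens per KV block


@dataclass
class ToyRequest:
    """Hypothetical stand-in for a request's scheduling state."""
    num_prompt_tokens: int
    num_output_tokens: int = 0
    num_computed_tokens: int = 0

    @property
    def num_tokens(self) -> int:
        # Prompt plus everything generated so far.
        return self.num_prompt_tokens + self.num_output_tokens

    @property
    def is_prefilling(self) -> bool:
        # Still climbing toward the end of the prompt.
        return self.num_computed_tokens < self.num_prompt_tokens

    def kv_blocks_held(self) -> int:
        # Every computed token lives in some block, so round up.
        return math.ceil(self.num_computed_tokens / BLOCK_SIZE)


req = ToyRequest(num_prompt_tokens=100)
print(req.is_prefilling, req.kv_blocks_held())  # True 0
req.num_computed_tokens = 100  # prompt fully prefilled
req.num_output_tokens = 1      # first token sampled; now one token behind
print(req.is_prefilling, req.kv_blocks_held())  # False 7
```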

Slicing the long prefill

Start with the interference problem, because its fix is almost embarrassingly small once the budget exists. A prefill of length $L$ is just $L$ tokens of work that all happen to belong to one request. The scheduler does not actually care that they belong to one request; it cares about the budget. So instead of scheduling all $L$ at once, it schedules $\min(L, \text{remaining budget})$ of them, advances num_computed_tokens by that much, and comes back for the rest on the next step. The prompt gets prefilled in chunks — a chunk being just one budget-sized slice of a prefill — and crucially those slices share each step with the decodes of every other running request.

The diagram below contrasts the two ways a 30,000-token prefill can hit the schedule. Without chunking, the long prefill monopolizes a step and the other requests’ decodes wait; with chunking, the same prefill is dribbled in a slice at a time so each step still has room for everyone’s decode token, keeping inter-token latency (ITL) smooth.

flowchart TD
    subgraph NoChunk["Without chunking (one step)"]
        A["30000-token prefill fills the whole step"] --> B["every other request's decode waits"]
        B --> C["inter-token latency spikes for the batch"]
    end
    subgraph Chunk["With chunking (many steps)"]
        D["step 1: prefill slice (budget) plus decodes"] --> E["step 2: next prefill slice plus decodes"]
        E --> F["step 3: next prefill slice plus decodes"]
        F --> G["prefill finishes, decodes never stalled"]
    end

Here is the clamp, in the WAITING-request path of the scheduler:

# vllm/v1/core/sched/scheduler.py
num_new_tokens = request.num_tokens - num_computed_tokens
threshold = self.scheduler_config.long_prefill_token_threshold
if 0 < threshold < num_new_tokens:
    num_new_tokens = threshold
...
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0

Source: vllm/v1/core/sched/scheduler.py

That min(num_new_tokens, token_budget) is the whole idea. Walk it left to right. num_new_tokens starts as everything the request still owes: request.num_tokens - num_computed_tokens, the gap between how long the request is and how much of it has been computed. The long_prefill_token_threshold clamp (we will come back to it) optionally tightens that, and then min(..., token_budget) caps it at whatever budget is left for this step. The result is the number of tokens this request gets to advance now. Whatever is left over simply stays uncomputed. The request joins (or remains in) the running set, and the next schedule() call picks it up exactly where it left off, because num_computed_tokens is advanced by the number of tokens actually scheduled in each step. So a long prefill makes progress one budget-sized bite per step, automatically. No state machine tracks “which chunk are we on”: the chunk position is num_computed_tokens. There is no separate “chunked prefill” code path in any meaningful sense; chunking is just what the budget cap does to a prefill that is larger than the budget.
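Here is a minimal sketch of the clamp in isolation. It assumes a single request that gets the whole budget every step, with no competing decodes. tokens_this_step and prefill_chunks are hypothetical helpers, not vLLM functions, and the 8192-token budget is only an example value. The point is that the loop’s only state is num_computed, which plays the role of num_computed_tokens.

```python
def tokens_this_step(num_tokens: int, num_computed_tokens: int,
                     token_budget: int, threshold: int = 0) -> int:
    """Mirror of the scheduler's clamp: how far a request may advance now."""
    num_new_tokens = num_tokens - num_computed_tokens
    # Optional per-request ceiling; 0 means disabled.
    if 0 < threshold < num_new_tokens:
        num_new_tokens = threshold
    # Never exceed what is left of this step's budget.
    return min(num_new_tokens, token_budget)


def prefill_chunks(prompt_len: int, token_budget: int, threshold: int = 0) -> list[int]:
    """Chunk sizes a lone prefill receives, one entry per step."""
    assert token_budget > 0, "a zero budget would never make progress"
    chunks = []
    num_computed = 0  # the chunk cursor is just the computed-token count
    while num_computed < prompt_len:
        n = tokens_this_step(prompt_len, num_computed, token_budget, threshold)
        chunks.append(n)
        num_computed += n  # advance by the amount actually scheduled
    return chunks


print(prefill_chunks(30_000, token_budget=8192))
print(len(prefill_chunks(30_000, token_budget=8192, threshold=2048)))
```

With an 8192-token budget the 30,000-token prompt lands in four steps (8192, 8192, 8192, then 5424). A 2048-token threshold stretches it to 15.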

The one thing a flag controls is whether to allow it at all. enable_chunked_prefill is on by default. When it is off, the WAITING path refuses to slice a prompt that exceeds the remaining budget (that check sits in the lines elided by ... above). It will not start any prefill it cannot finish in a single step, which is exactly the all-or-nothing behavior chunking was meant to replace. The running-request pass, which runs earlier in the same schedule() call, uses the identical pattern:

# vllm/v1/core/sched/scheduler.py
num_new_tokens = (
    request.num_tokens_with_spec
    + request.num_output_placeholders
    - request.num_computed_tokens
)
if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens:
    num_new_tokens = self.scheduler_config.long_prefill_token_threshold
num_new_tokens = min(num_new_tokens, token_budget)

Source: vllm/v1/core/sched/scheduler.py

A decode request asks for one token (its computed count is one behind its real length), gets it, and moves on. A prefilling request asks for thousands, gets a budget-sized chunk, and stays. Same three lines. The scheduler genuinely does not distinguish them, which is the payoff the previous chapter promised.
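The “same three lines” claim is easy to see in a toy single-step pass over the running list. It is a minimal sketch, assuming requests are visited in list order and the budget is shared first come, first served. schedule_step and the (id, num_tokens, num_computed_tokens) tuples are hypothetical. The decodes and the prefill go through identical code; only their owed token counts differ. The second call also shows the weakness of the budget cap alone: a long prefill visited first can take the entire step.

```python
def schedule_step(running: list[tuple[str, int, int]], token_budget: int,
                  threshold: int = 0) -> dict[str, int]:
    """One pass over running requests, in order, sharing a single budget.

    Each entry is (request_id, num_tokens, num_computed_tokens).
    """
    scheduled = {}
    for req_id, num_tokens, num_computed in running:
        if token_budget == 0:
            break  # budget exhausted; later requests wait for the next step
        num_new = num_tokens - num_computed
        if 0 < threshold < num_new:
            num_new = threshold
        num_new = min(num_new, token_budget)
        scheduled[req_id] = num_new
        token_budget -= num_new
    return scheduled


decodes = [(f"d{i}", 501, 500) for i in range(4)]  # each one token behind
prefill = ("p", 30_000, 0)                          # long prompt, nothing computed

print(schedule_step(decodes + [prefill], token_budget=2048))
print(schedule_step([prefill] + decodes, token_budget=2048))
print(schedule_step([prefill] + decodes, token_budget=2048, threshold=1024))
```

In the first call every decode gets its one token and the prefill gets the remaining 2044. In the second, the prefill takes all 2048 and the decodes get nothing. A per-request threshold restores room for them.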

Notice the second clamp, the one gated on long_prefill_token_threshold. The budget cap alone is enough to chunk a prefill, but it is a blunt instrument. If the budget is large, a single huge prefill can still claim most of a step’s budget, leaving little for the decodes scheduled after it. long_prefill_token_threshold is a tighter, per-request ceiling specifically for prompts the operator considers “long.” Its default is 0, meaning disabled, and the 0 < guard is what makes that work: the clamp only fires when someone has set a positive threshold. When concurrent partial prefills are turned on (max_num_partial_prefills > 1), the config quietly derives one for you:

# vllm/config/scheduler.py
if self.max_num_partial_prefills > 1:
    if self.long_prefill_token_threshold == 0:
        self.long_prefill_token_threshold = int(max_model_len * 0.04)

Source: vllm/config/scheduler.py

Four percent of the model’s context length is a heuristic, not a law. For a 32,768-token context it works out to int(32768 * 0.04) = 1310 tokens. The point is that you now have two independent knobs over prefill granularity. The global max_num_batched_tokens budget bounds the whole step, and long_prefill_token_threshold bounds any one prefill within it. Tightening the threshold guarantees more room for decodes to ride alongside a long prefill, at the cost of stretching that prefill’s time-to-first-token across more steps. That is exactly the tradeoff named in this part’s recurring theme: you are buying ITL stability for the batch by spending TTFT on the long request. There is no free chunk size, only a choice of where to sit on that curve. The curve below makes that frontier concrete for a 30,000-token prefill. Each point is one chunk size, and moving toward smaller chunks slides you left (smoother batch ITL) and up (longer TTFT for the long prompt).

Figure (illustrative): TTFT of a 30,000-token prompt versus batch ITL across prefill chunk sizes. The shape of the frontier and its direction with chunk size are real; the millisecond values are representative, not measured. Point labels are the per-step prefill chunk size in tokens.
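The direction of that frontier can be reproduced with a toy cost model, and every piece of it is an assumption. Step latency is modeled as a fixed overhead plus a linear per-token cost; fixed_ms and per_token_ms are made-up constants. Batch ITL while the prefill is in flight is taken to be one step’s latency, and the long prompt’s TTFT is taken as the sum of its prefill steps. frontier_point is a hypothetical helper meant to show direction, not real numbers.

```python
import math


def frontier_point(prompt_len: int, chunk: int, num_decodes: int,
                   fixed_ms: float = 10.0, per_token_ms: float = 0.02) -> tuple[float, float]:
    """Toy linear step-cost model (assumed constants, not measurements).

    Returns (itl_ms, ttft_ms) for a long prefill split into `chunk`-sized
    slices, where each step also carries one token per decoding request.
    """
    step_ms = fixed_ms + per_token_ms * (chunk + num_decodes)
    steps = math.ceil(prompt_len / chunk)
    # Decoders see one step's latency per token while the prefill runs;
    # the long prompt waits for all of its prefill steps.
    return step_ms, steps * step_ms


for chunk in (512, 1024, 2048, 4096, 8192):
    itl, ttft = frontier_point(30_000, chunk, num_decodes=32)
    print(f"chunk={chunk:5d}  ITL~{itl:6.1f} ms  TTFT~{ttft:7.1f} ms")
```

In this model, a smaller chunk shortens each step but multiplies the number of steps. The fixed per-step overhead is therefore paid more often, which is what pushes TTFT up as ITL comes down.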

This is the mechanism that Sarathi-Serve formalized and that we flagged back in Chapter 2 when we first met the throughput-latency tension. The paper’s framing is stall-free batching: by admitting prefill work only in budgeted slices that fit alongside ongoing decodes, you stop prefills from stalling decodes, and the decode stream stays smooth under prefill-heavy load. vLLM’s scheduler is, in this respect, an implementation of that idea, with the budget cap doing the slicing and the threshold giving you a sharper knife when you want one.

When the cache, not the budget, runs out

Chunking handles compute interference. It does nothing for the other scarce resource. Every token you prefill or decode needs a KV slot, and slots come from a fixed pool of blocks. The token budget bounds how much you compute per step; it says nothing about whether the cumulative KV footprint of all running requests still fits. Under sustained load — many long sequences all decoding, all growing their caches one block at a time — the pool eventually cannot satisfy the next allocation. The engine has over-committed, and now it must shed.

The shedding happens inside the allocation loop. When the scheduler tries to give a running request its next chunk, it asks the KV cache manager for blocks (allocate_slots), and the manager returns None if there is not enough free memory to hold them. That None is the trigger that turns a memory shortage into an eviction:

# vllm/v1/core/sched/scheduler.py
while True:
    new_blocks = self.kv_cache_manager.allocate_slots(
        request,
        num_new_tokens,
        num_lookahead_tokens=self.num_lookahead_tokens,
    )
    if new_blocks is not None:
        # The request can be scheduled.
        break
    # The request cannot be scheduled.
    # Preempt the lowest-priority request.
    if self.policy == SchedulingPolicy.PRIORITY:
        preempted_req = max(
            self.running,
            key=lambda r: (r.priority, r.arrival_time),
        )
        self.running.remove(preempted_req)
        ...
    else:
        preempted_req = self.running.pop()
    self._preempt_request(preempted_req, scheduled_timestamp)
    preempted_reqs.append(preempted_req)
    if preempted_req == request:
        # No more request to preempt. Cannot schedule this request.
        break

Source: vllm/v1/core/sched/scheduler.py

Read this as a backpressure valve. The engine cannot make memory; the only thing it can do is reclaim it from a request that currently holds some. So when an allocation fails, it picks a victim, frees that victim’s blocks, and tries the allocation again. It loops until either the allocation succeeds or there is no one left to evict. Under the default FCFS policy it evicts from the tail of the running list: self.running.pop() takes the most recently admitted, least senior request. That is a sensible victim choice because it protects the older requests, which have usually built up the most KV state and made the most progress toward finishing. Under priority scheduling it instead picks the genuinely lowest-priority request: the one with the highest priority value, with ties going to the latest arrival. And there is a self-protecting terminal case. If the request the scheduler is trying to schedule is itself the lowest-priority thing in the batch, the victim picker can land on it, at which point preempted_req == request and the loop gives up. The engine would rather make no progress on this request than evict something more important in order to serve it.

The diagram below traces that loop: each failed allocation sheds one request and retries, and the loop has exactly two exits — the allocation succeeding, or the scheduler running out of anything cheaper to evict than the request it is working on.

flowchart TD
    A["need blocks for next chunk"] --> B["allocate_slots(request)"]
    B --> C{"got blocks?"}
    C -->|yes| D["schedule request: proceed"]
    C -->|no| E["pick victim (FCFS tail or lowest priority)"]
    E --> F["_preempt_request(victim): free its blocks"]
    F --> G{"victim is the request itself?"}
    G -->|no| B
    G -->|yes| H["give up: cannot schedule this request now"]
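The allocate-or-evict loop can be sketched with a toy block pool. This is a minimal sketch under simplifying assumptions. Req and ToyScheduler are hypothetical. Each request’s next-chunk block demand is given up front. The request being scheduled must already be in the running list, as it is in the scheduler’s running pass. Victim choice follows the two policies described above: pop the tail under FCFS, or take the max of (priority, arrival_time) under priority.

```python
from dataclasses import dataclass, field


@dataclass
class Req:
    req_id: str
    blocks_needed: int       # blocks for its next chunk (assumed known)
    priority: int = 0        # larger value = lower priority
    arrival_time: float = 0.0
    blocks_held: int = 0


@dataclass
class ToyScheduler:
    free_blocks: int
    running: list = field(default_factory=list)
    preempted: list = field(default_factory=list)
    priority_policy: bool = False

    def try_schedule(self, request: Req) -> bool:
        """Allocate for `request`, evicting victims until it fits or gives up."""
        while True:
            if request.blocks_needed <= self.free_blocks:
                self.free_blocks -= request.blocks_needed
                request.blocks_held += request.blocks_needed
                return True
            # Allocation failed: choose a victim and reclaim its blocks.
            if self.priority_policy:
                victim = max(self.running, key=lambda r: (r.priority, r.arrival_time))
                self.running.remove(victim)
            else:
                victim = self.running.pop()  # FCFS: most recently admitted
            self.free_blocks += victim.blocks_held
            victim.blocks_held = 0
            self.preempted.append(victim.req_id)
            if victim is request:
                return False  # nothing cheaper left to evict


a = Req("a", 2, arrival_time=0.0, blocks_held=4)
b = Req("b", 2, arrival_time=1.0, blocks_held=4)
c = Req("c", 2, arrival_time=2.0, blocks_held=4)
s = ToyScheduler(free_blocks=0, running=[a, b, c])
print(s.try_schedule(a), s.preempted, s.free_blocks)  # True ['c'] 2
```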

What does preemption actually do? This is the part that distinguishes LLM serving from a generic load-shedding story, so it is worth reading line by line:

# vllm/v1/core/sched/scheduler.py
def _preempt_request(self, request: Request, timestamp: float) -> None:
    assert request.status == RequestStatus.RUNNING, (
        "Only running requests can be preempted"
    )
    self.kv_cache_manager.free(request)
    self.encoder_cache_manager.free(request)
    self._inflight_prefills.discard(request)
    request.status = RequestStatus.PREEMPTED
    request.num_computed_tokens = 0
    if request.spec_token_ids:
        request.spec_token_ids = []
    request.num_preemptions += 1
    if self.log_stats:
        request.record_event(EngineCoreEventType.PREEMPTED, timestamp)
    # Put the request back to the waiting queue.
    self.waiting.prepend_request(request)

Source: vllm/v1/core/sched/scheduler.py

Three things matter here. First, kv_cache_manager.free(request) returns every block the request held to the pool. That is the entire point: it is where the reclaimed memory comes from, the actual payload of the whole eviction. Second, request.num_computed_tokens = 0. The request does not pause; it is rewound to the very beginning. Every token it had already prefilled and every token it had already decoded is now marked uncomputed. The generated tokens themselves are kept; only their KV entries are discarded. On readmission the request therefore owes its prompt plus everything it has generated so far, and the scheduler rebuilds all of it as one ordinary (chunkable) prefill rather than decoding the output again one token at a time. This is preemption with recompute, and it is a deliberate design choice. Rather than swapping the KV cache out to host memory and copying it back later (which costs PCIe bandwidth and a good deal of complexity), vLLM throws the cache away and pays to rebuild it. The bet is that recompute on a GPU is often cheaper than a round trip over PCIe, and that preemption is rare enough that the recompute cost stays in the noise. Third, self.waiting.prepend_request(request) puts the victim at the front of the waiting queue, not the back. A preempted request jumps the line ahead of never-started requests, because it has already waited once and we want it to resume as soon as enough memory frees up.
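The three effects can be mirrored in a few lines. SimpleRequest and preempt are hypothetical, plain strings stand in for RequestStatus, and block freeing is left as a comment because this toy holds no blocks. Two things are worth checking. A collections.deque with appendleft gives the front-of-queue placement. After the rewind, the request owes its whole length (prompt plus generated output), which the clamp schedules like any other prefill.

```python
from collections import deque
from dataclasses import dataclass


@dataclass
class SimpleRequest:
    req_id: str
    num_prompt_tokens: int
    num_output_tokens: int
    num_computed_tokens: int
    status: str = "RUNNING"
    num_preemptions: int = 0

    @property
    def num_tokens(self) -> int:
        return self.num_prompt_tokens + self.num_output_tokens


def preempt(request: SimpleRequest, waiting: deque) -> None:
    """Toy preemption: rewind the request and put it at the head of the queue."""
    assert request.status == "RUNNING", "only running requests can be preempted"
    # (Freeing the request's KV blocks would happen here.)
    request.status = "PREEMPTED"
    request.num_computed_tokens = 0   # rewind: all KV state is discarded
    request.num_preemptions += 1      # generated tokens themselves are kept
    waiting.appendleft(request)       # front of the line, not the back


waiting = deque([SimpleRequest("new", 50, 0, 0, status="WAITING")])
victim = SimpleRequest("old", 1_000, 400, 1_399)  # decoding: one token behind
preempt(victim, waiting)
print([r.req_id for r in waiting])                     # ['old', 'new']
print(victim.num_tokens - victim.num_computed_tokens)  # 1400
```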

The diagram below follows a single request through that cycle. Note that preemption is not a terminal state: the request goes RUNNING to PREEMPTED to the front of WAITING and then back to RUNNING, losing its computed progress but keeping its place near the head of the line. A request can make this loop more than once under sustained pressure, which is exactly what num_preemptions counts.

stateDiagram-v2
    [*] --> WAITING: request arrives
    WAITING --> RUNNING: admitted, blocks allocated
    RUNNING --> RUNNING: prefill chunk or decode token
    RUNNING --> PREEMPTED: chosen as victim, free blocks, reset num_computed_tokens to 0
    PREEMPTED --> WAITING: prepended to front of queue
    RUNNING --> [*]: finished (EOS or max tokens)

The recompute cost is not entirely wasted, and this is where Chapter 7’s prefix caching quietly earns its keep ahead of schedule. When a preempted request is readmitted, its prompt blocks may still be sitting in the prefix cache (preemption frees the request’s references, but blocks linger until evicted). The “recompute from zero” can therefore land a long prefix-cache hit and skip most of the work it nominally threw away. The per-request counter num_preemptions is bumped on the way out. That feeds the metric we will get to, and it lets downstream prefix-cache accounting tag this request as having been preempted; you saw preempted=request.num_preemptions > 0 threaded through the stats calls in the previous chapter’s code.
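The saving is plain arithmetic. This is a minimal sketch with assumptions stated up front. tokens_to_recompute is a hypothetical helper, prefix hits are counted only in whole leading blocks, and the block size is 16. As a simplification, at least the final token is always recomputed so the model can produce logits for the next sample.

```python
def tokens_to_recompute(num_tokens: int, cached_blocks: int, block_size: int = 16) -> int:
    """Tokens a readmitted request must actually run through the model.

    Assumes hits cover only whole leading blocks, and that the last token
    is always recomputed (a simplification for producing next-token logits).
    """
    hit_tokens = min(cached_blocks * block_size, num_tokens - 1)
    return num_tokens - hit_tokens


# 1,000-token prompt + 400 generated tokens; the prompt's 62 full blocks survived.
print(tokens_to_recompute(1_400, cached_blocks=62))  # 408
print(tokens_to_recompute(1_400, cached_blocks=0))   # 1400
```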

Admission control: not over-committing in the first place

Preemption is the valve of last resort. The cheaper move is to not over-admit, and the KV cache manager has two mechanisms for that. The first is the watermark — a slice of the pool the manager refuses to hand out to newly admitted requests:

# vllm/v1/core/kv_cache_manager.py
watermark_blocks = 0
# The watermark is applied to waiting/preempted requests only, and only
# when there's at least one request already scheduled.
if has_scheduled_reqs and request.status in (
    RequestStatus.WAITING,
    RequestStatus.PREEMPTED,
):
    watermark_blocks = self.watermark_blocks

Source: vllm/v1/core/kv_cache_manager.py

and the gate itself:

# vllm/v1/core/kv_cache_manager.py
available_blocks = self.block_pool.get_num_free_blocks() - reserved_blocks
required_blocks = num_blocks_to_allocate + watermark_blocks
if required_blocks > available_blocks:
    # Cannot allocate new blocks
    return None

Source: vllm/v1/core/kv_cache_manager.py

The asymmetry is the clever part. The watermark blocks are added to required_blocks — raising the bar an allocation must clear — only for requests in WAITING or PREEMPTED state, that is, requests being admitted into the running set, and only when something is already running. An already-running request decoding its next token faces a watermark_blocks of 0 and is never blocked by the reserve; it can dip into it freely. So the reserve is invisible to requests that already have a foot in the door and a wall to requests trying to get in. The effect is hysteresis: a deliberate gap between the memory level at which you stop admitting new work and the level at which you actually run out. You stop letting new sequences in while a band of memory is still free, which leaves headroom for the sequences you already committed to keep growing for a few more steps before they collide. Without it, the engine would admit greedily right up to the last free block and then thrash, preempting on nearly every step as running requests fight over the last slivers of memory. The watermark trades a little admitted concurrency for far fewer preemptions. Its default is 0.0, disabled, because the right reserve depends on your sequence-length distribution — but on memory-pressured deployments it is one of the first knobs to reach for when you see preemptions climbing.
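The asymmetry fits in one function. This is a minimal sketch of the gate as described, not the KV cache manager itself. can_allocate is hypothetical, strings stand in for RequestStatus, and the 64-block reserve is an illustrative value. The same 100 free blocks admit a running request’s small allocation but refuse a waiting request’s, and the reserve disappears when nothing is scheduled yet.

```python
WAITING, PREEMPTED, RUNNING = "WAITING", "PREEMPTED", "RUNNING"


def can_allocate(num_free_blocks: int, num_blocks_to_allocate: int,
                 status: str, has_scheduled_reqs: bool,
                 watermark_blocks: int, reserved_blocks: int = 0) -> bool:
    """Toy version of the gate: the reserve applies only to admissions."""
    watermark = 0
    if has_scheduled_reqs and status in (WAITING, PREEMPTED):
        watermark = watermark_blocks
    available = num_free_blocks - reserved_blocks
    return num_blocks_to_allocate + watermark <= available


# 100 free blocks and a 64-block reserve.
print(can_allocate(100, 10, RUNNING, True, watermark_blocks=64))   # True
print(can_allocate(100, 40, WAITING, True, watermark_blocks=64))   # False
print(can_allocate(100, 40, WAITING, False, watermark_blocks=64))  # True
```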

The second mechanism is the full-sequence admission check, and it exists to plug a hole that chunking itself opened. Recall that admitting a request only allocates blocks for its first chunk. That is great for getting work started cheaply, but it means the admission decision is made against a tiny slice of the request’s eventual footprint. A 30,000-token prompt looks, at admission time, like a request needing only one budget-sized chunk’s worth of blocks. So the allocation succeeds even when the request’s full length could never fit in the pool. It sails in and prefills happily for a while as its cache grows chunk by chunk. Then, once that cache outgrows the room that was actually available, it preempts something (or gets preempted): a thrash the engine could have foreseen at admission. The full_sequence_must_fit path closes that trap by sizing the admission check against the whole sequence up front rather than the first chunk:

# vllm/v1/core/kv_cache_manager.py
if full_sequence_must_fit:
    # First check and fail if the full request sequence won't fit.
    full_num_tokens = min(request.num_tokens, self.max_model_len)
    num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate(
        request_id=request.request_id,
        num_tokens=full_num_tokens,
        ...
        apply_admission_cap=True,
    )
    required_blocks = num_blocks_to_allocate + watermark_blocks
    if required_blocks > self.block_pool.get_num_free_blocks():
        return None

Source: vllm/v1/core/kv_cache_manager.py

The scheduler enables this by default via scheduler_reserve_full_isl, whose own docstring names the failure mode it prevents: “checking the first chunk… Prevents over-admission and KV cache thrashing with chunked prefill.” So the two halves of this chapter are coupled. Chunked prefill makes admission cheap to fake, and the full-sequence check makes sure the engine does not fool itself with that cheapness.
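The trap and its fix show up in a toy comparison. This is a minimal sketch under stated assumptions. admit and blocks_for are hypothetical helpers. The block size of 16 and the 32,768-token max_model_len are illustrative. The watermark and the apply_admission_cap detail from the excerpt are deliberately left out. A 30,000-token prompt needs 1875 blocks in total but only 512 for an 8192-token first chunk, so 1000 free blocks pass the first-chunk check and fail the full-sequence one.

```python
import math

BLOCK_SIZE = 16  # hypothetical tokens per block


def blocks_for(num_tokens: int, block_size: int = BLOCK_SIZE) -> int:
    return math.ceil(num_tokens / block_size)


def admit(prompt_len: int, first_chunk: int, num_free_blocks: int,
          full_sequence_must_fit: bool, max_model_len: int = 32_768) -> bool:
    """Toy admission check sized against the first chunk or the whole prompt."""
    if full_sequence_must_fit:
        # Size against the whole sequence, capped at the model's context length.
        needed = blocks_for(min(prompt_len, max_model_len))
    else:
        needed = blocks_for(first_chunk)
    return needed <= num_free_blocks


print(admit(30_000, 8192, 1000, full_sequence_must_fit=False))  # True
print(admit(30_000, 8192, 1000, full_sequence_must_fit=True))   # False
```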

Watermark, full-sequence check, and preemption form a graded response to memory pressure. They are three escalating defenses, ordered from cheapest and most preventive to most expensive and most reactive. The diagram below shows where each one sits on the path a request takes when it asks for blocks. The first two run at admission time and cost nothing but some refused concurrency. Preemption runs only when both have been bypassed (an already-running request, exempt from the watermark, grew until the pool was genuinely exhausted), and it costs a full recompute. The design goal is to keep control near the top of this ladder and out of the expensive rung at the bottom.

flowchart TD
    A["request wants blocks"] --> B{"new admission and someone running?"}
    B -->|yes| C{"fits above the watermark reserve?"}
    B -->|no| F{"free blocks for this chunk?"}
    C -->|no| R1["reject: keep waiting (cheap, preventive)"]
    C -->|yes| D{"full sequence fits? (reserve_full_isl)"}
    D -->|no| R2["reject: would thrash later (cheap, preventive)"]
    D -->|yes| F
    F -->|yes| OK["allocate and run"]
    F -->|no| P["preempt a victim: free blocks, full recompute later (expensive)"]

Watching the valve

Because preemption is the symptom of memory pressure, it is the single most diagnostic counter the engine emits. When stats logging is enabled, each _preempt_request call records a PREEMPTED event, and the stats layer turns each event into an increment:

# vllm/v1/metrics/stats.py
elif event.type == EngineCoreEventType.PREEMPTED:
    self.num_preempted_reqs += 1

Source: vllm/v1/metrics/stats.py

which the Prometheus logger publishes as a cumulative counter:

# vllm/v1/metrics/loggers.py
counter_num_preempted_reqs = self._counter_cls(
    name="vllm:num_preemptions",
    documentation="Cumulative number of preemption from the engine.",
    labelnames=labelnames,
)

Source: vllm/v1/metrics/loggers.py

A healthy replica’s vllm:num_preemptions is flat or nearly so. A rising slope means the engine is shedding load to survive, every preemption is paying a full recompute, and effective goodput is bleeding into wasted work. The operator’s responses, in roughly increasing cost: raise the watermark to admit less aggressively, lower max_num_seqs to cap concurrency, lower max_num_batched_tokens (which also shrinks chunks and eases the cache), or add capacity. We will return to this counter in Chapter 21 as one of the first signals to read when a replica’s latency regresses, because it localizes the problem precisely: a preemption spike is unambiguously a memory-admission problem, not a kernel or a routing one.
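Reading the slope means turning the cumulative counter into a rate. In practice you would use PromQL’s rate() over the exported series. With the Python prometheus_client library, a counter is typically exposed with a _total suffix. The arithmetic is simple enough to sketch, though. preemption_rate and the 0.1-per-second alert threshold are hypothetical. A drop in the counter is treated as a reset (for example a process restart), in the spirit of how Prometheus handles counter resets.

```python
def preemption_rate(samples: list[tuple[float, float]]) -> float:
    """Per-second increase of a cumulative counter across the samples.

    samples: (unix_time_seconds, counter_value) pairs, oldest first.
    A decrease is treated as a counter reset that restarted from zero.
    """
    increase = 0.0
    for (_, prev), (_, cur) in zip(samples, samples[1:]):
        increase += cur - prev if cur >= prev else cur
    elapsed = samples[-1][0] - samples[0][0]
    return increase / elapsed if elapsed > 0 else 0.0


ALERT_PER_SEC = 0.1  # hypothetical threshold; tune per deployment

healthy = [(0.0, 12.0), (60.0, 12.0), (120.0, 13.0)]
pressured = [(0.0, 12.0), (60.0, 150.0), (120.0, 310.0)]
for name, series in (("healthy", healthy), ("pressured", pressured)):
    rate = preemption_rate(series)
    print(f"{name}: {rate:.3f}/s", "ALERT" if rate > ALERT_PER_SEC else "ok")
```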

What is unsolved

Recompute-on-preempt is a pragmatic default, not a settled answer. Throwing away a long sequence’s KV cache and rebuilding it is genuinely expensive when most of the sequence is generated output rather than a cached prompt. The rebuild is a prefill over the prompt plus every generated token, and whatever part of that sequence has already fallen out of the prefix cache must be computed again, competing for budget with the requests that stayed resident. Swapping the cache to host memory instead of discarding it trades that recompute for PCIe bandwidth. Which is cheaper depends on sequence shape, link speed, and how long the request stays evicted. vLLM picks recompute as the simpler primitive, but the choice is workload-dependent and not obviously right for every deployment.

Victim selection is similarly unsophisticated. FCFS evicts the youngest running request and priority evicts the lowest-priority one, but neither asks the question that would actually minimize wasted work: which request, if preempted, costs the least to rebuild? A request whose prompt is fully prefix-cached is nearly free to resume, while one deep into an uncached generation must re-prefill its entire length. A preemption policy that weighed recompute cost against seniority could shed load far more cheaply, and nothing in the current scheduler does. And the deepest tension remains the one this chapter only managed. The chunk-size choice that smooths ITL also lengthens TTFT for long prompts, and no setting escapes that curve; settings only move you along it. Sarathi-Serve quantifies that frontier; it does not abolish it.

The natural next question is whether all this recompute is necessary at all. Much of what a preempted request rebuilds, and much of what every fresh request prefills, is shared — system prompts, few-shot exemplars, conversation history that recurs across requests. If the idle KV pool could double as a content-addressed cache, the engine could pay for a shared prefix exactly once and hand it to every request that needs it, turning the most expensive part of prefill into a lookup. That is the subject of the next chapter, and it is the single largest TTFT lever in real traffic.

Further reading

  • Taming Throughput-Latency Tradeoff in LLM Inference with Sarathi-Serve — arXiv:2403.02310 — Introduces chunked-prefill and stall-free batching; the formal account of the budget-slicing mechanism this chapter grounds in vLLM’s scheduler, and the source for the TTFT-vs-ITL frontier you are tuning against.