vLLM
vLLM is the de-facto open-source LLM inference engine: it owns everything from the HTTP socket down to the CUDA kernel for a single model replica. In a production serving stack it sits exactly where an application server sits behind your load balancers — Envoy/gateway-api-inference-extension or a router like sglang-router/Dynamo picks a replica, and vLLM turns each request into scheduled GPU work: continuous batching, paged KV cache with prefix reuse, grammar-constrained sampling, speculative decoding, and streaming detokenization back out over SSE. This guide maps the concepts you already know to the actual code in the V1 engine (vllm/v1/), which is the only engine in current builds (the legacy V0 codepaths are gone; “V1” survives as the package name). Repo state as read: commit 3b03a2cf4772 (June 2026).
Why you care
- It is a traffic system in miniature. The scheduler is a token-level rate limiter with a strict budget per tick (max_num_batched_tokens), a priority/FCFS admission queue, watermark-based admission control, and preemption-as-loadshedding. Every intuition you have about queue depth, head-of-line blocking, and backpressure applies one-for-one, just with “tokens/step” instead of “bytes/sec”.
- Every router you’ll work on terminates here. gateway-api-inference-extension’s Envoy endpoint picker, sglang-router, Dynamo, and llm-d all make routing decisions from signals vLLM emits: vllm:num_requests_waiting, vllm:kv_cache_usage_perc, prefix-cache hit events, and KV-transfer handshakes for P/D disaggregation. Knowing what those numbers mean inside the engine is what separates “moves traffic” from “owns inference infra”.
- The process architecture is a microservice. Frontend (asyncio, tokenization, SSE egress) and EngineCore (scheduler + GPU loop) are separate OS processes joined by ZMQ + msgpack. This is a deliberate “data plane vs control plane” split so Python GIL work on egress never stalls the GPU step loop. You have debugged this exact shape at 500TB/day.
Architecture map
| Path | What it is |
|---|---|
vllm/entrypoints/openai/ | FastAPI app, OpenAI-compatible routes (chat_completion/, completion/, responses/), CLI args |
vllm/entrypoints/cli/ | vllm serve, vllm bench {serve,latency,throughput} subcommands |
vllm/v1/engine/async_llm.py | AsyncLLM — frontend engine client; per-request asyncio generators + background output handler |
vllm/v1/engine/core.py | EngineCore (schedule→execute→output step) and EngineCoreProc (busy loop in its own process) |
vllm/v1/engine/core_client.py | ZMQ/msgpack RPC between frontend and EngineCore (MPClient/AsyncMPClient) |
vllm/v1/engine/output_processor.py, detokenizer.py | EngineCoreOutputs → RequestOutput; incremental detokenization, stop strings |
vllm/v1/core/sched/scheduler.py | The scheduler: continuous batching, chunked prefill, preemption, KV-connector hooks |
vllm/v1/core/kv_cache_manager.py, block_pool.py, kv_cache_utils.py | Paged KV allocator, prefix-cache hash table, LRU free queue |
vllm/v1/executor/ | Fans SchedulerOutput to worker processes (multiproc/Ray), collects ModelRunnerOutput |
vllm/v1/worker/gpu_worker.py, gpu_model_runner.py | Per-GPU execution: input batch assembly, forward, sampling, D2H bookkeeping (gpu/ is the experimental runner V2) |
vllm/v1/attention/backends/ | Attention impls: flash_attn.py, flashinfer.py, triton_attn.py, mla/ — selected via registry.py |
vllm/v1/sample/ | GPU Sampler, top-k/top-p ops (optional FlashInfer kernels), rejection sampler for spec decode |
vllm/v1/structured_output/ | Grammar manager + backends (backend_xgrammar.py, guidance, outlines) |
vllm/v1/spec_decode/ | EAGLE / ngram draft proposers |
vllm/v1/metrics/loggers.py, prometheus.py | All vllm:* Prometheus series |
vllm/compilation/ | torch.compile integration: graph splitting, piecewise backend, CUDAGraphWrapper |
vllm/model_executor/ | Model definitions and layers; the attention custom op lives here |
vllm/distributed/kv_transfer/ | KV connectors for P/D disaggregation (NIXL, mooncake, …) |
csrc/ | CUDA/C++ kernels: cache ops, MoE, quant; classic PagedAttention in csrc/libtorch_stable/attention/ |
Core mechanisms
1. V1 engine flow: AsyncLLM → EngineCore → tokens out
Two processes. The frontend (AsyncLLM) tokenizes and validates, then ships an EngineCoreRequest over ZMQ to the EngineCore process, which runs a busy loop. The class docstring at vllm/v1/engine/core_client.py:464 states the contract:
```python
class MPClient(EngineCoreClient):
    """
    MPClient: base client for multi-proc EngineCore.
    EngineCore runs in a background process busy loop, getting
    new EngineCoreRequests and returning EngineCoreOutputs
    * pushes EngineCoreRequests via input_socket
    * pulls EngineCoreOutputs via output_socket
    """
```
The heart of the engine is one synchronous step — schedule, launch forward, overlap CPU grammar work with the GPU, sample, then reconcile (vllm/v1/engine/core.py:443):
```python
def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
    """Schedule, execute, and make output. ..."""
    if not self.scheduler.has_requests():
        return {}, False
    scheduler_output = self.scheduler.schedule()
    future = self.model_executor.execute_model(scheduler_output, non_block=True)
    grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
    ...
    model_output = future.result()
    if model_output is None:
        model_output = self.model_executor.sample_tokens(grammar_output)
    ...
    engine_core_outputs = self.scheduler.update_from_output(
        scheduler_output, model_output
    )
```
Note the overlap trick: execute_model(..., non_block=True) launches the forward pass, then the CPU fills grammar bitmasks while the GPU runs, and sample_tokens(grammar_output) finishes the step. The busy loop wrapping this is vllm/v1/engine/core.py:1223 (run_busy_loop: poll input queue → _process_engine_step). With async_scheduling there is a deeper pipelined variant, step_with_batch_queue at vllm/v1/engine/core.py:484, which schedules step N+1 before step N’s sample completes.
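The overlap trick can be mimicked on a CPU. The following is a minimal sketch, assuming a worker thread as a stand-in for the GPU stream; fake_forward, fake_fill_bitmask and step are hypothetical names, not vLLM APIs. The point is only the ordering: submit the forward pass without waiting, do the grammar work, then join.

```python
import time
from concurrent.futures import ThreadPoolExecutor


# Hypothetical stand-ins: a worker thread plays the role of the GPU.
def fake_forward(batch: list[str]) -> dict[str, float]:
    time.sleep(0.05)  # simulated forward-pass time
    return {req: float(len(req)) for req in batch}


def fake_fill_bitmask(batch: list[str]) -> dict[str, bool]:
    time.sleep(0.05)  # simulated CPU grammar work of similar cost
    return {req: True for req in batch}


def step(executor: ThreadPoolExecutor, batch: list[str]) -> dict[str, float]:
    # 1. Launch the forward pass and return immediately (cf. non_block=True).
    future = executor.submit(fake_forward, batch)
    # 2. Overlap: fill the grammar masks while the "GPU" is busy.
    masks = fake_fill_bitmask(batch)
    # 3. Join (cf. future.result()), then combine both results for sampling.
    logits = future.result()
    return {req: logits[req] if masks[req] else float("-inf") for req in batch}


with ThreadPoolExecutor(max_workers=1) as pool:
    start = time.perf_counter()
    result = step(pool, ["a", "bb"])
    elapsed = time.perf_counter() - start
# The two 50 ms waits overlap, so elapsed lands near one wait rather than two.
print(result, f"{elapsed * 1000:.0f} ms")
```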
On the way back, a single background asyncio task per frontend (AsyncLLM._run_output_handler, vllm/v1/engine/async_llm.py:637) pulls EngineCoreOutputs batches off the ZMQ socket, runs the output processor in bounded chunks (VLLM_V1_OUTPUT_PROC_CHUNK_SIZE) so it never hogs the event loop, and pushes RequestOutputs into per-request queues. Each generate() call is just a consumer of its own queue (vllm/v1/engine/async_llm.py:576):
```python
finished = False
while not finished:
    # Note: drain queue without await if possible (avoids
    # task switching under load which helps performance).
    out = q.get_nowait() or await q.get()
    assert isinstance(out, RequestOutput)
    finished = out.finished
    if out is not STREAM_FINISHED:
        yield out
```
Client disconnects surface as asyncio.CancelledError here and trigger an abort RPC back into the EngineCore — the request lifecycle is fully bidirectional, like stream resets in gRPC.
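The fan-out pattern from this subsection, shrunk to a toy: one handler routes batched outputs into per-request queues, each generator drains only its own queue, and cancelling a consumer records an abort. ToyFrontend, output_handler and the aborted list are hypothetical stand-ins for AsyncLLM, _run_output_handler and the abort RPC, and the sketch uses a plain asyncio.Queue for simplicity rather than whatever queue type vLLM uses.

```python
import asyncio
from collections.abc import AsyncIterator


class ToyFrontend:
    """Hypothetical miniature of the AsyncLLM output path."""

    def __init__(self) -> None:
        self.queues: dict[str, asyncio.Queue] = {}
        self.aborted: list[str] = []  # stands in for abort RPCs to the engine

    async def output_handler(self, batches: list[list[tuple[str, str, bool]]]) -> None:
        # Each batch holds (request_id, text_delta, finished) for many requests.
        for batch in batches:
            for req_id, text, finished in batch:
                q = self.queues.get(req_id)
                if q is not None:  # skip requests that were already aborted
                    q.put_nowait((text, finished))
            await asyncio.sleep(0)  # yield between batches so consumers can run

    async def generate(self, req_id: str) -> AsyncIterator[str]:
        q: asyncio.Queue = asyncio.Queue()
        self.queues[req_id] = q
        try:
            finished = False
            while not finished:
                text, finished = await q.get()
                yield text
        except asyncio.CancelledError:
            self.aborted.append(req_id)  # client went away: abort upstream
            raise
        finally:
            self.queues.pop(req_id, None)


async def demo() -> tuple[list[str], list[str]]:
    fe = ToyFrontend()

    async def consume(req_id: str) -> list[str]:
        return [text async for text in fe.generate(req_id)]

    task_a = asyncio.create_task(consume("a"))
    task_b = asyncio.create_task(consume("b"))
    await asyncio.sleep(0)  # let both consumers register their queues
    await fe.output_handler([
        [("a", "Hel", False), ("b", "Hi", False)],
        [("a", "lo", True)],
    ])
    text_a = await task_a
    task_b.cancel()  # simulate a client disconnect on "b"
    try:
        await task_b
    except asyncio.CancelledError:
        pass
    return text_a, fe.aborted


print(asyncio.run(demo()))  # (['Hel', 'lo'], ['b'])
```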
2. Scheduler: continuous batching as a token budget
The scheduler discards the classic prefill/decode dichotomy entirely. The design note at vllm/v1/core/sched/scheduler.py:355:
```python
# NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
# Each request just has the num_computed_tokens and
# num_tokens_with_spec. num_tokens_with_spec =
# len(prompt_token_ids) + len(output_token_ids) + len(spec_token_ids).
# At each step, the scheduler tries to assign tokens to the requests
# so that each request's num_computed_tokens can catch up its
# num_tokens_with_spec. This is general enough to cover
# chunked prefills, prefix caching, speculative decoding,
# and the "jump decoding" optimization in the future.
```
Every step gets a budget token_budget = self.max_num_scheduled_tokens (vllm/v1/core/sched/scheduler.py:373, i.e. --max-num-batched-tokens, default 2048 per vllm/config/scheduler.py:42). Pass 1 walks RUNNING requests (decodes ask for 1 token, mid-prefill requests ask for the remainder) and clamps each ask (vllm/v1/core/sched/scheduler.py:416):
```python
num_new_tokens = (
    request.num_tokens_with_spec
    + request.num_output_placeholders
    - request.num_computed_tokens
)
if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens:
    num_new_tokens = self.scheduler_config.long_prefill_token_threshold
num_new_tokens = min(num_new_tokens, token_budget)
```
Pass 2 admits WAITING requests only if nothing was preempted this step (vllm/v1/core/sched/scheduler.py:579), checks the prefix cache for each (get_computed_blocks), and implements chunked prefill as a simple min against the remaining budget (vllm/v1/core/sched/scheduler.py:741): num_new_tokens = request.num_tokens - num_computed_tokens then num_new_tokens = min(num_new_tokens, token_budget). A prompt that doesn’t fit this step simply continues next step — chunking falls out of the bookkeeping for free. Decodes and prefill chunks ride in the same batch; there is no separate prefill queue.
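A toy model of the two passes shows chunked prefill falling out of the bookkeeping. This is a hypothetical sketch: Req, clamp, schedule and commit are made-up names, there are no KV limits or preemption, and sampling is modeled as appending one token whenever a request catches up.

```python
from dataclasses import dataclass


@dataclass
class Req:
    rid: str
    num_tokens: int        # prompt tokens + tokens sampled so far
    num_computed: int = 0  # tokens whose KV has already been computed


def clamp(req: Req, budget: int, long_prefill_threshold: int) -> int:
    ask = req.num_tokens - req.num_computed
    if 0 < long_prefill_threshold < ask:
        ask = long_prefill_threshold  # cap very long prefills per step
    return min(ask, budget)


def schedule(running: list[Req], waiting: list[Req], budget: int,
             long_prefill_threshold: int = 0) -> dict[str, int]:
    """One toy scheduling step; returns {request_id: tokens this step}."""
    plan: dict[str, int] = {}
    # Pass 1: running requests (decodes ask for 1, mid-prefills for the rest).
    for req in running:
        if budget == 0:
            break
        plan[req.rid] = clamp(req, budget, long_prefill_threshold)
        budget -= plan[req.rid]
    # Pass 2: admit waiting requests while budget remains; a prompt that does
    # not fit gets a partial chunk and simply continues next step.
    while waiting and budget > 0:
        req = waiting.pop(0)
        plan[req.rid] = clamp(req, budget, long_prefill_threshold)
        budget -= plan[req.rid]
        running.append(req)
    return plan


def commit(running: list[Req], plan: dict[str, int]) -> None:
    # Pretend the forward pass ran: scheduled tokens now have KV.
    for req in running:
        if req.rid in plan:
            req.num_computed += plan[req.rid]
            if req.num_computed == req.num_tokens:
                req.num_tokens += 1  # caught up: one new token is sampled


running = [Req("decode", num_tokens=101, num_computed=100)]
waiting = [Req("prompt", num_tokens=6000)]
plans = []
for _ in range(4):
    plan = schedule(running, waiting, budget=2048)
    commit(running, plan)
    plans.append(plan)
print(plans)
# [{'decode': 1, 'prompt': 2047}, {'decode': 1, 'prompt': 2047},
#  {'decode': 1, 'prompt': 1906}, {'decode': 1, 'prompt': 1}]
```

The decode gets its one token every step while the 6000-token prompt is sliced across three steps. Passing long_prefill_threshold=256 would cap the prompt at 256 tokens per step instead.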
Preemption is the backpressure valve: if allocate_slots returns None (no free KV blocks), the scheduler evicts the lowest-priority / most-recently-arrived running request and retries in a loop (vllm/v1/core/sched/scheduler.py:474-518). Preemption is brutal and simple (vllm/v1/core/sched/scheduler.py:1033):
```python
def _preempt_request(self, request: Request, timestamp: float) -> None:
    ...
    self.kv_cache_manager.free(request)
    self.encoder_cache_manager.free(request)
    ...
    request.status = RequestStatus.PREEMPTED
    request.num_computed_tokens = 0
    ...
    # Put the request back to the waiting queue.
    self.waiting.prepend_request(request)
```
All of the request’s KV is dropped, and on readmission it re-prefills from scratch: the prompt plus every token generated so far (prefix cache hits soften the cost). After the GPU returns, update_from_output (vllm/v1/core/sched/scheduler.py:1388) appends sampled tokens, handles spec-decode rejections, and emits EngineCoreOutputs. The two queues self.waiting/self.running plus this per-step re-planning are what continuous batching is.
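The preempt-and-retry loop, reduced to block counts. ToyScheduler and its methods are hypothetical; the victim choice (newest running request) mirrors the FCFS case described above, and priority scheduling is not modeled.

```python
from collections import deque
from dataclasses import dataclass


@dataclass(eq=False)  # compare requests by identity
class Req:
    rid: str
    blocks_needed: int      # blocks its KV needs after this step
    num_computed: int = 0
    blocks_held: int = 0


class ToyScheduler:
    """Hypothetical preempt-and-retry sketch over a fixed block budget."""

    def __init__(self, total_blocks: int) -> None:
        self.free_blocks = total_blocks
        self.running: list[Req] = []
        self.waiting: deque = deque()
        self.num_preemptions = 0

    def allocate(self, req: Req) -> bool:
        need = req.blocks_needed - req.blocks_held
        if need > self.free_blocks:
            return False  # like allocate_slots returning None
        self.free_blocks -= need
        req.blocks_held += need
        return True

    def preempt(self, req: Req) -> None:
        self.running.remove(req)
        self.free_blocks += req.blocks_held  # drop ALL of its KV
        req.blocks_held = 0
        req.num_computed = 0                 # recompute from scratch later
        self.waiting.appendleft(req)         # back to the front of the queue
        self.num_preemptions += 1

    def schedule_running(self) -> list[str]:
        scheduled = []
        for req in list(self.running):
            if req not in self.running:
                continue  # already preempted as a victim in this step
            while not self.allocate(req):
                victim = self.running[-1]  # newest running request
                self.preempt(victim)
                if victim is req:
                    break  # it preempted itself; nothing more to do
            else:
                scheduled.append(req.rid)
        return scheduled


s = ToyScheduler(total_blocks=10)
a, b, c = Req("a", 4), Req("b", 4), Req("c", 2, num_computed=32)
s.running = [a, b, c]
first = s.schedule_running()    # 4 + 4 + 2 blocks: everything fits
a.blocks_needed = 6             # "a" keeps generating and needs 2 more blocks
second = s.schedule_running()   # pool is full, so "c" is evicted to make room
print(first, second, s.num_preemptions, c.num_computed)  # ['a', 'b', 'c'] ['a', 'b'] 1 0
```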
3. KV cache: block pool, prefix caching, eviction
KVCacheManager (vllm/v1/core/kv_cache_manager.py:110) fronts a BlockPool of fixed-size blocks (16 tokens default). Prefix caching is a chained content hash: each request eagerly hashes its tokens into per-block hashes where every hash commits to the whole prefix (vllm/v1/core/kv_cache_utils.py:563):
```python
def hash_block_tokens(
    hash_function: Callable[[Any], bytes],
    parent_block_hash: BlockHash | None,
    curr_block_token_ids: Sequence[int],
    extra_keys: tuple[Any, ...] | None = None,
) -> BlockHash:
    ...
    if not parent_block_hash:
        parent_block_hash = NONE_HASH
    curr_block_token_ids_tuple = tuple(curr_block_token_ids)
    return BlockHash(
        hash_function((parent_block_hash, curr_block_token_ids_tuple, extra_keys))
    )
```
extra_keys folds in things that change KV content beyond token ids: multimodal hashes, LoRA, cache_salt (the hasher is built in vllm/v1/core/kv_cache_utils.py:659). NONE_HASH, the seed for the first block of every chain, comes from os.urandom at :111 unless pinned; pinning it is what gives cross-worker determinism for distributed prefix lookups. On admission, lookup is one longest-prefix-match walk (vllm/v1/core/kv_cache_manager.py:227); note the “must recompute the last token to get logits” cap at :221. Cache hits bump refcounts and rescue blocks from the free queue (vllm/v1/core/block_pool.py:402):
```python
def touch(self, blocks: Sequence[KVCacheBlock]) -> None:
    for block in blocks:
        # ref_cnt=0 means this block is in the free list (i.e. eviction
        # candidate), so remove it.
        if block.ref_cnt == 0 and not block.is_null:
            self.free_block_queue.remove(block)
        block.ref_cnt += 1
```
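The chained hash and the longest-prefix walk can be sketched with the standard library. SHA-256 over a pickled tuple, the fixed NONE_HASH and the 4-token block size are assumptions for readability, not vLLM’s exact configuration; hash_block, block_hashes and longest_cached_prefix are hypothetical names, and the “recompute the last token” cap is omitted.

```python
import hashlib
import pickle
from typing import Any, Optional

BLOCK_SIZE = 4            # vLLM defaults to 16; 4 keeps the example small
NONE_HASH = b"\0" * 32    # fixed seed for the demo (vLLM randomizes it unless pinned)


def hash_block(parent: Optional[bytes], tokens: tuple[int, ...],
               extra_keys: Optional[tuple[Any, ...]] = None) -> bytes:
    # The parent hash is part of the input, so each hash commits to the whole prefix.
    payload = pickle.dumps((parent or NONE_HASH, tokens, extra_keys))
    return hashlib.sha256(payload).digest()


def block_hashes(token_ids: list[int],
                 extra_keys: Optional[tuple[Any, ...]] = None) -> list[bytes]:
    hashes: list[bytes] = []
    parent: Optional[bytes] = None
    # Only full blocks are hashed; a trailing partial block is never shared.
    for start in range(0, len(token_ids) - BLOCK_SIZE + 1, BLOCK_SIZE):
        parent = hash_block(parent, tuple(token_ids[start:start + BLOCK_SIZE]), extra_keys)
        hashes.append(parent)
    return hashes


def longest_cached_prefix(hashes: list[bytes], cache: dict[bytes, int]) -> list[int]:
    # Walk from the first block and stop at the first miss.
    hit_blocks: list[int] = []
    for h in hashes:
        if h not in cache:
            break
        hit_blocks.append(cache[h])
    return hit_blocks


system_prompt = list(range(8))  # two full blocks shared by every request
cache = {h: block_id for block_id, h in
         enumerate(block_hashes(system_prompt + [100, 101, 102, 103]))}
other = block_hashes(system_prompt + [200, 201, 202, 203, 204])
salted = block_hashes(system_prompt + [200, 201, 202, 203], extra_keys=("salt-1",))
print(longest_cached_prefix(other, cache))   # [0, 1]: only the shared prefix hits
print(longest_cached_prefix(salted, cache))  # []: a different salt forks the chain
```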
Eviction is lazy LRU: freed blocks keep their hash and stay in cached_block_hash_to_block while sitting in FreeKVCacheBlockQueue — an intrusive doubly-linked list with O(1) middle-removal, ordered LRU-first then tail-of-chain-first (vllm/v1/core/kv_cache_utils.py:165). Only when a block is reallocated does _maybe_evict_cached_block strip its hash from the cache map (vllm/v1/core/block_pool.py:365) — i.e., the entire idle KV pool doubles as prefix cache. Blocks become reusable as soon as they fill: allocate_slots commits finalized full blocks every step (vllm/v1/core/kv_cache_manager.py:452), so a long generation’s prefix is shareable while it is still generating. Admission control uses watermark headroom (:363-370 — waiting requests can’t starve running ones), and full_sequence_must_fit / reserved_blocks (:278-286) gate chunked-prefill over-admission and async KV-connector loads — connection-pool-style reservation logic.
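The lazy-eviction idea is easy to model. In this hypothetical sketch an OrderedDict stands in for the intrusive doubly-linked FreeKVCacheBlockQueue (both support O(1) removal from the middle); ToyBlockPool and its methods are made-up names, and the tail-of-chain-first ordering is not modeled.

```python
from collections import OrderedDict
from typing import Optional


class Block:
    def __init__(self, block_id: int) -> None:
        self.block_id = block_id
        self.ref_cnt = 0
        self.block_hash: Optional[bytes] = None


class ToyBlockPool:
    """Hypothetical lazy-LRU pool: a freed block keeps its hash until reused."""

    def __init__(self, num_blocks: int) -> None:
        blocks = [Block(i) for i in range(num_blocks)]
        self.free_queue = OrderedDict((b.block_id, b) for b in blocks)  # LRU first
        self.cached: dict[bytes, Block] = {}  # like cached_block_hash_to_block

    def touch(self, block: Block) -> None:
        if block.ref_cnt == 0:  # idle block: rescue it from eviction
            del self.free_queue[block.block_id]
        block.ref_cnt += 1

    def lookup(self, block_hash: bytes) -> Optional[Block]:
        block = self.cached.get(block_hash)
        if block is not None:
            self.touch(block)
        return block

    def allocate(self) -> Block:
        _, block = self.free_queue.popitem(last=False)  # take the LRU block
        if block.block_hash is not None:  # eviction happens only now, lazily
            del self.cached[block.block_hash]
            block.block_hash = None
        block.ref_cnt = 1
        return block

    def commit(self, block: Block, block_hash: bytes) -> None:
        block.block_hash = block_hash  # a full block becomes shareable
        self.cached[block_hash] = block

    def free(self, block: Block) -> None:
        block.ref_cnt -= 1
        if block.ref_cnt == 0:
            self.free_queue[block.block_id] = block  # MRU end; hash is kept


pool = ToyBlockPool(num_blocks=2)
a = pool.allocate()
pool.commit(a, b"prefix-A")
pool.free(a)                           # idle, but still a prefix-cache entry
hit = pool.lookup(b"prefix-A")         # cache hit pulls it back out of the free queue
pool.free(hit)
b = pool.allocate()                    # takes block 1, the LRU entry
survived = b"prefix-A" in pool.cached  # block 0 is free but still cached
c = pool.allocate()                    # only block 0 is left, so its hash is evicted now
print(hit is a, b.block_id, survived, c.block_id, b"prefix-A" in pool.cached)
# True 1 True 0 False
```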
4. Structured output: CPU bitmask, GPU mask
Grammar compilation (xgrammar by default) happens off the hot path in a thread pool when a request arrives (StructuredOutputManager.grammar_init, vllm/v1/structured_output/__init__.py:115); until compilation finishes the request is parked in state WAITING_FOR_STRUCTURED_OUTPUT_GRAMMAR. Each step, while the GPU runs the forward pass (see step() above), the scheduler calls grammar_bitmask() (vllm/v1/structured_output/__init__.py:204) which has each request’s FSM matcher write one bit per vocab token into a shared CPU tensor (vllm/v1/structured_output/__init__.py:186):
```python
def _fill_bitmasks(
    self, batch: Iterable[tuple[StructuredOutputGrammar, int, bool]]
) -> None:
    assert self._grammar_bitmask is not None
    for grammar, index, apply_bitmask in batch:
        if apply_bitmask and not grammar.is_terminated():
            grammar.fill_bitmask(self._grammar_bitmask, index)
        else:
            self._grammar_bitmask[index].fill_(self._full_mask)
```
fill_bitmask is a direct xgrammar call: self.matcher.fill_next_token_bitmask(bitmask, idx) (vllm/v1/structured_output/backend_xgrammar.py:191). Large batches are sharded across a thread pool in groups of 16 (:239-265); with spec decode the FSM is advanced per draft token then rolled back (:278-291). The tensor ships to workers as numpy because “that is much more efficient for serialization” (:297-300).
On the worker, after logits exist, the mask is reordered to batch order, copied H2D non_blocking, and applied in-place by xgrammar’s GPU kernel — invalid tokens become -inf before sampling (vllm/v1/structured_output/utils.py:100):
```python
# Copy async to device as tensor.
grammar_bitmask = torch.from_numpy(sorted_bitmask).to(
    logits.device, non_blocking=True
)
...
xgr.apply_token_bitmask_inplace(logits, grammar_bitmask, indices=index_tensor)
```
Call site: vllm/v1/worker/gpu_model_runner.py:4410, inside sample_tokens() right before self._sample(...) at :4415. So the per-step loop is exactly the one you know: CPU computes constraint → GPU applies → GPU samples → CPU advances FSM with the sampled token (accept_tokens, vllm/v1/structured_output/backend_xgrammar.py:148, called from the scheduler in update_from_output).
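What “one bit per vocab token” means concretely, as a pure-Python sketch. It assumes the layout in which token t maps to bit t % 32 of 32-bit word t // 32; pack_allowed and apply_bitmask_inplace are hypothetical helpers, whereas the real mask is an int32 tensor filled by xgrammar and applied by a GPU kernel.

```python
import math


def pack_allowed(allowed: set[int], vocab_size: int) -> list[int]:
    # Hypothetical helper. One bit per token: bit (t % 32) of word (t // 32).
    # Python ints stand in for the int32 words of the real tensor.
    words = [0] * math.ceil(vocab_size / 32)
    for t in allowed:
        words[t // 32] |= 1 << (t % 32)
    return words


def apply_bitmask_inplace(logits: list[float], words: list[int]) -> None:
    # Hypothetical CPU stand-in for the GPU kernel: disallowed tokens become -inf.
    for t in range(len(logits)):
        if not (words[t // 32] >> (t % 32)) & 1:
            logits[t] = -math.inf


vocab_size = 40
mask = pack_allowed({3, 7, 33}, vocab_size)  # e.g. tokens that keep the JSON valid
logits = [0.0] * vocab_size
logits[5] = 9.0                               # the model's favourite, but not allowed
apply_bitmask_inplace(logits, mask)
best = max(range(vocab_size), key=logits.__getitem__)
print(mask, best)  # [136, 2] 3
```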
5. Sampler: what runs on GPU, what crosses back
Sampler (vllm/v1/sample/sampler.py:20) is an nn.Module over the final-position logits only ([num_reqs, vocab]): logits_indices gathers the last token per request before the LM head, so the non-final positions of a prefill chunk never produce logits. The ordered pipeline is documented at :21-58: logprobs snapshot → fp32 → allowed-tokens/bad-words → min-tokens & logit-bias processors → penalties → temperature → min-p → top-k/top-p → sample. The forward core (vllm/v1/sample/sampler.py:95):
```python
# Use float32 for the logits.
logits = logits.to(torch.float32)
logits = self.apply_logits_processors(
    logits, sampling_metadata, predict_bonus_token
)
# Sample the next token.
sampled, processed_logprobs = self.sample(logits, sampling_metadata)
...
sampled = sampled.long()
```
Top-k/top-p uses FlashInfer’s sorting-free rejection-sampling kernels when available (vllm/v1/sample/ops/topk_topp_sampler.py:70, flashinfer_sample at :471), else a native PyTorch path. Greedy requests short-circuit to argmax. The only data that must cross GPU→CPU each step is the sampled token ids (plus optional logprobs): _bookkeeping_sync (vllm/v1/worker/gpu_model_runner.py:3557) does valid_sampled_token_ids = self._to_list(sampled_token_ids) (:3616) via a pinned-memory buffer + dedicated copy stream. With async_scheduling even that sync is dodged: token ids stay on-GPU as prev_sampled_token_ids and are spliced into the next step’s input buffer (:3636-3642), letting the CPU schedule step N+1 blind. Sampled ids are not sent back over ZMQ redundantly — the model runner caches them and the scheduler tracks counts (:3649-3653).
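The temperature → min-p → top-k/top-p → sample tail of that pipeline, for a single request in plain Python. This is a hypothetical sketch: sample_token is a made-up name, penalties, logit bias and logprobs are omitted, and FlashInfer’s sorting-free kernels are replaced by an explicit sort.

```python
import math
import random
from typing import Optional


def sample_token(logits: list[float], temperature: float = 1.0, min_p: float = 0.0,
                 top_k: int = 0, top_p: float = 1.0,
                 rng: Optional[random.Random] = None) -> int:
    """Hypothetical single-request sampler."""
    if temperature == 0.0:  # greedy requests short-circuit to argmax
        return max(range(len(logits)), key=logits.__getitem__)
    # Temperature, then a numerically stable softmax.
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # min-p: drop tokens below min_p times the top probability.
    if min_p > 0.0:
        cutoff = min_p * max(probs)
        probs = [p if p >= cutoff else 0.0 for p in probs]
    # top-k: keep the k most likely surviving tokens.
    order = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)
    keep = [i for i in (order[:top_k] if top_k > 0 else order) if probs[i] > 0.0]
    # top-p: smallest prefix of the (renormalized) kept set reaching top_p mass.
    mass = sum(probs[i] for i in keep)
    nucleus, acc = [], 0.0
    for i in keep:
        nucleus.append(i)
        acc += probs[i] / mass
        if acc >= top_p:
            break
    rng = rng or random.Random()
    return rng.choices(nucleus, weights=[probs[i] for i in nucleus], k=1)[0]


logits = [1.0, 3.0, 2.0]
rng = random.Random(0)
print(sample_token(logits, temperature=0.0))     # 1 (argmax)
print(sample_token(logits, top_p=0.5, rng=rng))  # 1: token 1 alone holds ~0.67 of the mass
print(sample_token(logits, min_p=0.5, rng=rng))  # 1: token 2 (~0.24) is below 0.5 * 0.67
print({sample_token(logits, top_k=2, rng=rng) for _ in range(200)} <= {1, 2})  # True
```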
6. CUDA graphs + piecewise torch.compile
Decode steps are tiny (one token per sequence), so kernel-launch overhead dominates; vLLM amortizes it by replaying pre-captured CUDA graphs. The wrinkle: attention with a growing paged KV cache is awkward to capture. The solution is piecewise capture — torch.compile traces the model, then the FX graph is split at attention ops (split_graph, vllm/compilation/backends.py:548), with the splitting-op list defined at vllm/config/compilation.py:745:
```python
# Attention ops; used for piecewise cudagraphs
# Use PyTorch operator format: "namespace::name"
_attention_ops: ClassVar[list[str]] = [
    "vllm::unified_attention_with_output",
    "vllm::unified_mla_attention_with_output",
    "vllm::mamba_mixer2",
    ...
]
```
Everything between attentions (MLP, norms, projections — shape-static given a padded token count) gets compiled and captured as cudagraphs by PiecewiseBackend (vllm/compilation/piecewise_backend.py:86); attention runs eager in the gaps. At runtime, CudagraphDispatcher.dispatch (vllm/v1/cudagraph_dispatcher.py:239) pads the batch up to the nearest captured size and selects FULL (pure uniform decode, attention captured too) vs PIECEWISE vs NONE; modes enumerated at vllm/config/compilation.py:53 (FULL_AND_PIECEWISE = decode uses full graphs, mixed prefill batches use piecewise). The wrapper itself is cache-keyed on BatchDescriptor (vllm/compilation/cuda_graph.py:233):
```python
entry = self.concrete_cudagraph_entries[batch_descriptor]
if entry.cudagraph is None:
    ...
    cudagraph = torch.cuda.CUDAGraph()
    ...
    with torch.cuda.graph(
        cudagraph,
        pool=self.graph_pool,
        stream=current_stream(),
    ):
        # `output` is managed by pytorch's cudagraph pool
        output = self.runnable(*args, **kwargs)
```
…and replay is just entry.cudagraph.replay() (vllm/compilation/cuda_graph.py:360). Replay requires identical input addresses, so the runner owns persistent input buffers and copies each step’s data into them; debug mode asserts data_ptr() equality (:346-355).
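Both halves of the story fit in a toy: splitting a layer’s op list at attention ops, and padding a batch up to a captured size at runtime. The op names, the CAPTURE_SIZES list and the mode rules below are illustrative assumptions that only loosely follow FULL_AND_PIECEWISE; they are not vLLM’s actual defaults, and split_at_attention and dispatch are hypothetical names.

```python
import bisect

ATTENTION_OPS = {"attn"}  # stand-in for "vllm::unified_attention_with_output" etc.
CAPTURE_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]  # illustrative sizes


def split_at_attention(ops: list[str]) -> list[tuple[str, list[str]]]:
    """Hypothetical: cut a linear op sequence into capturable pieces and eager attention gaps."""
    pieces: list[tuple[str, list[str]]] = []
    current: list[str] = []
    for op in ops:
        if op in ATTENTION_OPS:
            if current:
                pieces.append(("graph", current))  # static-shape piece: capture it
                current = []
            pieces.append(("eager", [op]))         # attention runs eagerly
        else:
            current.append(op)
    if current:
        pieces.append(("graph", current))
    return pieces


def dispatch(num_tokens: int, uniform_decode: bool) -> tuple[str, int]:
    """Hypothetical: pick a cudagraph mode and the padded token count to replay with."""
    idx = bisect.bisect_left(CAPTURE_SIZES, num_tokens)
    if idx == len(CAPTURE_SIZES):
        return "NONE", num_tokens  # larger than any captured graph
    padded = CAPTURE_SIZES[idx]    # pad up to the nearest captured size
    return ("FULL" if uniform_decode else "PIECEWISE"), padded


print(split_at_attention(["norm", "qkv_proj", "attn", "o_proj", "norm", "mlp"]))
# [('graph', ['norm', 'qkv_proj']), ('eager', ['attn']), ('graph', ['o_proj', 'norm', 'mlp'])]
print(dispatch(13, uniform_decode=True))     # ('FULL', 16)
print(dispatch(300, uniform_decode=False))   # ('PIECEWISE', 512)
print(dispatch(4096, uniform_decode=False))  # ('NONE', 4096)
```

Padding is also why the persistent input buffers matter: the padded batch must be copied into the same addresses the graph was captured with.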
7. Attention kernels and dispatch
Models never call a backend directly — every attention layer calls a registered PyTorch custom op which torch.compile treats as opaque (this is what makes graph splitting possible), vllm/model_executor/layers/attention/attention.py:734:
```python
def unified_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: LayerNameType,
    ...
) -> None:
    ...
    layer_name = _resolve_layer_name(layer_name)
    attn_metadata, self, kv_cache, _ = get_attention_context(layer_name)
    self.impl.forward(
        self, query, key, value, kv_cache, attn_metadata, output=output, ...)
```
(registered via direct_register_custom_op at :777). The impl is chosen at startup by get_attn_backend (vllm/v1/attention/selector.py:54) → platform logic → AttentionBackendEnum (vllm/v1/attention/backends/registry.py:34), which maps names to classes: FLASH_ATTN → vllm/v1/attention/backends/flash_attn.py, FLASHINFER → flashinfer.py, TRITON_ATTN, plus a whole mla/ family for DeepSeek-style latent attention. Override with --attention-backend FLASHINFER (vllm/engine/arg_utils.py:907).
The FlashAttention backend is the default on NVIDIA. Its forward (vllm/v1/attention/backends/flash_attn.py:698) first appends this step’s K/V into the paged cache (reshape_and_cache_flash, :919), then runs flash_attn_varlen_func (:839) over ragged sequences. The block_table argument makes it paged-aware, so “PagedAttention” today is a feature of the FA kernel rather than a separate kernel. The original PagedAttention v1/v2 CUDA kernels still exist at csrc/libtorch_stable/attention/paged_attention_v1.cu / paged_attention_v2.cu but are off the main path. One comment in this backend is worth internalizing for your performance instincts (vllm/v1/attention/backends/flash_attn.py:740):
```python
# IMPORTANT!
# NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
# eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
# in this method. For example, `view` and `slice` (or `[:n]`) operations
# are surprisingly slow even in the case they do not invoke any GPU ops.
```
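How a block table turns a logical token position into a physical cache slot, the addressing that both the cache write and the paged-aware kernel rely on. This is a hypothetical list-based sketch with a 4-token block size; slot_for, write_kv and gather_kv are made-up names, and the real cache is a set of GPU tensors written and read by CUDA kernels.

```python
BLOCK_SIZE = 4  # tokens per KV block (vLLM's default is 16)


def slot_for(block_table: list[int], pos: int) -> int:
    # Hypothetical: logical block pos // BLOCK_SIZE lives in physical block block_table[...].
    return block_table[pos // BLOCK_SIZE] * BLOCK_SIZE + pos % BLOCK_SIZE


def write_kv(cache: list, block_table: list[int], start_pos: int, new_kv: list) -> None:
    # Toy counterpart of appending this step's K/V into the paged cache.
    for offset, kv in enumerate(new_kv):
        cache[slot_for(block_table, start_pos + offset)] = kv


def gather_kv(cache: list, block_table: list[int], seq_len: int) -> list:
    # What a paged-aware attention kernel reads: the sequence in logical order.
    return [cache[slot_for(block_table, pos)] for pos in range(seq_len)]


cache = [None] * (4 * BLOCK_SIZE)  # 4 physical blocks
table = [2, 0]                     # logical block 0 -> physical 2, 1 -> physical 0
write_kv(cache, table, 0, ["k0", "k1", "k2", "k3", "k4"])  # 5-token prefill
write_kv(cache, table, 5, ["k5"])                          # one decode step
print(gather_kv(cache, table, 6))  # ['k0', 'k1', 'k2', 'k3', 'k4', 'k5']
print(cache[8], cache[0])          # k0 k4: contiguous logically, scattered physically
```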
8. Streaming egress: OpenAI server, detokenization, stop strings
vllm serve builds a FastAPI app (vllm/entrypoints/openai/api_server.py); /v1/chat/completions lives at vllm/entrypoints/openai/chat_completion/api_router.py:40 and returns StreamingResponse(content=generator, media_type="text/event-stream") (:74). The SSE generator (chat_completion_stream_generator, vllm/entrypoints/openai/chat_completion/serving.py:398) iterates AsyncLLM.generate() and emits chat.completion.chunk deltas, running tool-call/reasoning parsers incrementally per choice.
Underneath, OutputProcessor.process_outputs (vllm/v1/engine/output_processor.py:576) is deliberately “the only function that should loop over EngineCoreOutputs” per batch. Detokenization is incremental and stateful per request: the fast path wraps the HF tokenizers Rust DecodeStream (FastIncrementalDetokenizer, vllm/v1/engine/detokenizer.py:167), primed with the prompt ids so byte-level merges across the prompt/output boundary decode correctly. Stop-string handling is the subtle part — text, not tokens: a stop string can span token boundaries, so the detokenizer withholds max(len(stop)) - 1 chars from the stream (vllm/v1/engine/detokenizer.py:84):
```python
# Number of chars to hold back when stop strings are to be excluded
# from streamed output.
if self.stop and not self.include_stop_str_in_output:
    self.stop_buffer_length = max(len(s) for s in self.stop) - 1
else:
    self.stop_buffer_length = 0
```
check_stop_strings (vllm/v1/engine/detokenizer.py:309) scans only the newly decoded chars (plus the overlap window) and truncates output_text at the match. Because stop strings are detected in the frontend while the EngineCore is already generating the next tokens, the frontend fires an explicit abort back upstream (vllm/v1/engine/output_processor.py:678-681: “If req not finished in EngineCore, but Detokenizer detected stop string, abort needed in EngineCore”) — a classic distributed cancellation race, handled exactly like client-disconnect aborts. Prometheus is mounted as an ASGI sub-app at /metrics (vllm/entrypoints/serve/instrumentator/metrics.py:41); the series definitions are all in vllm/v1/metrics/loggers.py (vllm:num_requests_running :452, vllm:num_requests_waiting :462, vllm:kv_cache_usage_perc :520, vllm:prefix_cache_queries/hits :543/:554, vllm:num_preemptions :620).
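The holdback-and-scan logic can be sketched as a small streaming class. StopStringStreamer is hypothetical: it works on already-detokenized text deltas and only marks itself stopped at the point where the real frontend would also send the abort upstream.

```python
class StopStringStreamer:
    """Hypothetical streamer that withholds max(len(stop)) - 1 chars."""

    def __init__(self, stop: list[str]) -> None:
        self.stop = stop
        self.buffer_len = max((len(s) for s in stop), default=1) - 1
        self.text = ""
        self.sent = 0           # chars already streamed to the client
        self.stopped = False

    def push(self, new_text: str) -> str:
        """Append newly decoded text and return the delta that is safe to send."""
        if self.stopped:
            return ""
        # Only the new chars plus an overlap window can contain a fresh match.
        scan_from = max(0, len(self.text) - self.buffer_len)
        self.text += new_text
        hits = [i for i in (self.text.find(s, scan_from) for s in self.stop) if i != -1]
        if hits:
            self.text = self.text[:min(hits)]  # truncate at the earliest match
            self.stopped = True                # the real frontend also aborts upstream
        safe_end = len(self.text) if self.stopped else len(self.text) - self.buffer_len
        delta = self.text[self.sent:max(safe_end, self.sent)]
        self.sent += len(delta)
        return delta

    def flush(self) -> str:
        """At a normal finish, release whatever was held back."""
        delta = self.text[self.sent:]
        self.sent = len(self.text)
        return delta


s = StopStringStreamer(stop=["END"])
print([s.push(piece) for piece in ["Hello E", "N", "D world"]])  # ['Hello', ' ', '']
print(repr(s.text), s.stopped)                                   # 'Hello ' True

t = StopStringStreamer(stop=["END"])
print([t.push("Hello E"), t.push("xit"), t.flush()])            # ['Hello', ' Ex', 'it']
```

The “E” and “EN” that might start a stop string never reach the client; once “END” completes across the boundary, the output is cut at the match.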
Suggested reading path
1. vllm/v1/engine/core.py: read EngineCore.__init__, step() (:443), and EngineCoreProc.run_busy_loop (:1223) first; everything else hangs off this loop.
2. vllm/v1/core/sched/scheduler.py: schedule() (:353) top to bottom, then update_from_output (:1388). The single most important file in the repo.
3. vllm/v1/core/kv_cache_manager.py: get_computed_blocks (:202) and allocate_slots (:244), including the block-layout ASCII diagram at :290.
4. vllm/v1/core/block_pool.py + vllm/v1/core/kv_cache_utils.py: hash table, free-queue LRU, hash_block_tokens (kv_cache_utils.py:563).
5. vllm/v1/worker/gpu_model_runner.py: skim execute_model (:4000), sample_tokens (:4379), _bookkeeping_sync (:3557); this is where scheduler abstractions become tensors.
6. vllm/v1/engine/async_llm.py: generate() (:524) and _run_output_handler (:637) for the frontend half; glance at vllm/v1/engine/core_client.py:464 for the ZMQ seam.
7. vllm/v1/engine/output_processor.py + vllm/v1/engine/detokenizer.py: egress, stop strings, the abort race.
8. vllm/v1/structured_output/__init__.py + vllm/v1/structured_output/utils.py: the CPU/GPU bitmask handshake.
9. vllm/compilation/cuda_graph.py + vllm/v1/cudagraph_dispatcher.py: capture/replay and dispatch keys.
10. vllm/v1/attention/backends/flash_attn.py: one full backend end-to-end (metadata builder, cache write, varlen kernel call).
Bonus: vllm/v1/core/sched/output.py (SchedulerOutput — the exact wire contract between scheduler and workers) and vllm/v1/request.py (per-request state machine).
Connections to your study set
- nano-vllm: a ~1.2k-line reimplementation of exactly the files above. Its scheduler ≈ scheduler.py minus chunked prefill/connectors; its block manager ≈ kv_cache_manager.py + block_pool.py with the same parent-chained block hashing. Read nano-vllm first or alongside; every concept there has a 10x-more-edge-cases twin here.
- sglang: the main competing engine. Same continuous batching and paged KV, but prefix caching is a radix tree over token sequences (RadixAttention) instead of vLLM’s flat hash-of-full-blocks map; compare with block_pool.py’s cached_block_hash_to_block. Its Rust sgl-router does cache-aware routing by approximating each replica’s radix tree; vLLM’s equivalent signal is the --kv-events stream (BlockStored/BlockRemoved published from vllm/v1/core/block_pool.py:392) plus prefix-cache metrics.
- dynamo: NVIDIA’s Rust distributed layer above engines, handling KV-aware routing and P/D disaggregation orchestration. It plugs into vLLM through the KV-connector API you saw threaded through the scheduler (self.connector.get_num_new_matched_tokens at vllm/v1/core/sched/scheduler.py:674, the WAITING_FOR_REMOTE_KVS state at :866), with NIXL as the transfer fabric (vllm/distributed/kv_transfer/kv_connector/v1/).
- llm-d: K8s-native distributed inference built around vLLM replicas; its scheduler/EPP consumes the same /metrics series defined in vllm/v1/metrics/loggers.py for load- and cache-aware placement.
- gateway-api-inference-extension: your Envoy world. An ext-proc endpoint picker scrapes vllm:num_requests_waiting, vllm:kv_cache_usage_perc, and LoRA-adapter metrics from each pod to pick a backend. The chat router even accepts an endpoint-load-metrics-format header (vllm/entrypoints/openai/chat_completion/api_router.py:54) so the gateway can get load data inline with responses.
- xgrammar: the default structured-output backend. vLLM’s wrapper is vllm/v1/structured_output/backend_xgrammar.py (compile → fill_next_token_bitmask on CPU → apply_token_bitmask_inplace on GPU). vLLM’s Triton port of xgrammar’s apply kernel lives in the experimental runner at vllm/v1/worker/gpu/structured_outputs.py:86.
- flashinfer: an optional kernel library used in two places: as a full attention backend (vllm/v1/attention/backends/flashinfer.py, good paged/cascade decode kernels) and for sorting-free top-k/top-p sampling (vllm/v1/sample/ops/topk_topp_sampler.py:471).
Tinkering on one RTX 5080 (16GB)
Blackwell (SM120) needs a recent CUDA 12.8+ wheel — uv pip install vllm --torch-backend=auto handles it. Good fits in 16GB: Qwen/Qwen3-4B-Instruct-2507 (BF16, roomy KV), Qwen/Qwen3-8B-FP8, or RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w4a16. Keep --max-model-len 8192 to start; watch the startup log line that reports how many GPU KV blocks fit.
```shell
vllm serve Qwen/Qwen3-4B-Instruct-2507 \
  --max-model-len 8192 --gpu-memory-utilization 0.90
```
1. Watch the engine think. VLLM_LOGGING_LEVEL=DEBUG vllm serve ... shows “EngineCore waiting for work”, per-shape “Capturing a cudagraph (PIECEWISE, ...)” lines at startup, and scheduler activity; map each line to the code above.
2. Prefix caching A/B. Send 50 requests sharing a 2-4k-token system prompt, then rerun with --no-enable-prefix-caching. Compare TTFT and vllm:prefix_cache_queries vs vllm:prefix_cache_hits on /metrics. Use POST /reset_prefix_cache between trials for clean runs; add "cache_salt" to a request body to deliberately fork the hash chain (kv_cache_utils.py extra_keys).
3. Chunked prefill / budget pressure. Sweep --max-num-batched-tokens 512 2048 8192 while one client streams a decode-heavy chat and another fires 6k-token prompts. At 512, watch a single prefill get sliced across many steps (decode ITL stays smooth); at 8192, ITL spikes whenever a prefill lands. Then cap with --long-prefill-token-threshold 256 and watch the spike flatten: this is the min(num_new_tokens, token_budget) line doing its job.
4. max-num-seqs sweep + forced preemption. Run vllm bench serve against --max-num-seqs 16/64/128/256; find where throughput saturates and ITL inflates (KV-bound, not compute-bound). Then set --gpu-memory-utilization 0.5 --max-model-len 16384 and long generations to starve the block pool: vllm:num_preemptions ticks up and you’ll see _preempt_request recompute-from-zero behavior in tail latencies.
5. Prometheus to dashboard. examples/observability/prometheus_grafana/ ships a prometheus.yaml scrape config and a grafana.json dashboard (TTFT/ITL histograms, KV usage, queue depth). Wire it up and reproduce the queue-depth → kv-usage → preemption causality chain during experiment 4; this is the exact signal set the Envoy inference EPP routes on.
6. Real traces. Fetch the ShareGPT dataset, then benchmark against it:
   ```shell
   wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
   vllm bench serve --model Qwen/Qwen3-4B-Instruct-2507 \
     --dataset-name sharegpt --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \
     --request-rate 8 --num-prompts 500
   ```
   (implementation: vllm/benchmarks/serve.py, dataset loaders in vllm/benchmarks/datasets/datasets.py:1340). Sweep --request-rate to find the knee; compare goodput under Poisson arrivals vs --burstiness.
7. CUDA graph ablation. Run the same decode benchmark three ways: default, --compilation-config '{"cudagraph_mode": "NONE"}', and --enforce-eager (no compile at all). Single-stream decode ITL makes launch overhead visible; on a small model expect a 20-40% gap eager vs graphs.
8. Structured output cost. Fire the same prompts with and without response_format={"type": "json_schema", ...}. Per-step overhead is the CPU fill_next_token_bitmask + H2D mask copy from section 4; it shows up as an ITL delta that grows with batch size. To see where the time goes, attach py-spy top or py-spy record to the EngineCore process (py-spy dump gives a one-off stack snapshot) and look for the grammar_bitmask call inside EngineCore.step. Having to target a different PID than the API server is a nice demo that the frontend and engine really are separate processes.
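For the prefix-caching and preemption experiments, a ratio computed from two /metrics scrapes is easier to compare than raw cumulative counters. Below is a minimal stdlib sketch; parse_metrics, counter and prefix_hit_rate are hypothetical helpers, and because Prometheus client libraries usually expose counters with a _total suffix, it accepts both the bare and the suffixed name as an assumption.

```python
def parse_metrics(text: str) -> dict[str, float]:
    """Hypothetical helper: sum every sample per metric name from Prometheus text format, ignoring labels."""
    totals: dict[str, float] = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        if "{" in line:
            name = line.split("{", 1)[0]
            rest = line.rsplit("}", 1)[1]
        else:
            name, _, rest = line.partition(" ")
        value = float(rest.split()[0])  # ignore an optional trailing timestamp
        totals[name] = totals.get(name, 0.0) + value
    return totals


def counter(totals: dict[str, float], name: str) -> float:
    # Assumption: accept either "name_total" or the bare "name".
    return totals.get(name + "_total", totals.get(name, 0.0))


def prefix_hit_rate(before: str, after: str) -> float:
    """Hit ratio over the interval between two scrapes of /metrics."""
    b, a = parse_metrics(before), parse_metrics(after)
    queries = counter(a, "vllm:prefix_cache_queries") - counter(b, "vllm:prefix_cache_queries")
    hits = counter(a, "vllm:prefix_cache_hits") - counter(b, "vllm:prefix_cache_hits")
    return hits / queries if queries else 0.0


before = """
# TYPE vllm:prefix_cache_queries_total counter
vllm:prefix_cache_queries_total{engine="0",model_name="m"} 1000.0
vllm:prefix_cache_hits_total{engine="0",model_name="m"} 100.0
"""
after = """
vllm:prefix_cache_queries_total{engine="0",model_name="m"} 5000.0
vllm:prefix_cache_hits_total{engine="0",model_name="m"} 3100.0
"""
print(prefix_hit_rate(before, after))  # 0.75
```

Grab the two texts with any HTTP client (for example curl against the server’s /metrics) right before and right after each trial.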