Preface
A companion to the Inference Serving Roadmap: eight source-level study guides that map the concepts of LLM inference serving onto the actual code of the open-source stack — vLLM, a minimal teaching engine (nano-vllm), SGLang, FlashInfer, XGrammar, NVIDIA Dynamo, llm-d, and the Gateway API Inference Extension.
Each guide is written for an engineer who already understands large-scale traffic systems and wants to find where routing, batching, KV-cache management, and disaggregation actually live in real repositories. Code is quoted from each project under its own (permissive) license with file-path attribution; research papers are cited as jumping-off points, not reproduced.
Read vLLM first for the reference design, nano-vllm to see the same ideas in a few hundred lines, then branch out by interest.
vLLM
vLLM is the de-facto open-source LLM inference engine: it owns everything from the HTTP socket down to the CUDA kernel for a single model replica. In a production serving stack it sits exactly where an application server sits behind your load balancers — Envoy/gateway-api-inference-extension or a router like sglang-router/Dynamo picks a replica, and vLLM turns each request into scheduled GPU work: continuous batching, paged KV cache with prefix reuse, grammar-constrained sampling, speculative decoding, and streaming detokenization back out over SSE. This guide maps the concepts you already know to the actual code in the V1 engine (vllm/v1/), which is the only engine in current builds (the legacy V0 codepaths are gone; “V1” survives as the package name). Repo state as read: commit 3b03a2cf4772 (June 2026).
Why you care
- It is a traffic system in miniature. The scheduler is a token-level rate limiter with a strict budget per tick (max_num_batched_tokens), a priority/FCFS admission queue, watermark-based admission control, and preemption-as-loadshedding. Every intuition you have about queue depth, head-of-line blocking, and backpressure applies one-for-one, just with “tokens/step” instead of “bytes/sec”.
- Every router you’ll work on terminates here. gateway-api-inference-extension’s Envoy endpoint picker, sglang-router, Dynamo, and llm-d all make routing decisions from signals vLLM emits: vllm:num_requests_waiting, vllm:kv_cache_usage_perc, prefix-cache hit events, and KV-transfer handshakes for P/D disaggregation. Knowing what those numbers mean inside the engine is what separates “moves traffic” from “owns inference infra”.
- The process architecture is a microservice. Frontend (asyncio, tokenization, SSE egress) and EngineCore (scheduler + GPU loop) are separate OS processes joined by ZMQ + msgpack, a deliberate “data plane vs control plane” split so Python GIL work on egress never stalls the GPU step loop. You have debugged this exact shape at 500TB/day.
Architecture map
| Path | What it is |
|---|---|
| vllm/entrypoints/openai/ | FastAPI app, OpenAI-compatible routes (chat_completion/, completion/, responses/), CLI args |
| vllm/entrypoints/cli/ | vllm serve, vllm bench {serve,latency,throughput} subcommands |
| vllm/v1/engine/async_llm.py | AsyncLLM: frontend engine client; per-request asyncio generators + background output handler |
| vllm/v1/engine/core.py | EngineCore (schedule→execute→output step) and EngineCoreProc (busy loop in its own process) |
| vllm/v1/engine/core_client.py | ZMQ/msgpack RPC between frontend and EngineCore (MPClient/AsyncMPClient) |
| vllm/v1/engine/output_processor.py, detokenizer.py | EngineCoreOutputs → RequestOutput; incremental detokenization, stop strings |
| vllm/v1/core/sched/scheduler.py | The scheduler: continuous batching, chunked prefill, preemption, KV-connector hooks |
| vllm/v1/core/kv_cache_manager.py, block_pool.py, kv_cache_utils.py | Paged KV allocator, prefix-cache hash table, LRU free queue |
| vllm/v1/executor/ | Fans SchedulerOutput to worker processes (multiproc/Ray), collects ModelRunnerOutput |
| vllm/v1/worker/gpu_worker.py, gpu_model_runner.py | Per-GPU execution: input batch assembly, forward, sampling, D2H bookkeeping (gpu/ is the experimental runner V2) |
| vllm/v1/attention/backends/ | Attention impls: flash_attn.py, flashinfer.py, triton_attn.py, mla/; selected via registry.py |
| vllm/v1/sample/ | GPU Sampler, top-k/top-p ops (optional FlashInfer kernels), rejection sampler for spec decode |
| vllm/v1/structured_output/ | Grammar manager + backends (backend_xgrammar.py, guidance, outlines) |
| vllm/v1/spec_decode/ | EAGLE / ngram draft proposers |
| vllm/v1/metrics/loggers.py, prometheus.py | All vllm:* Prometheus series |
| vllm/compilation/ | torch.compile integration: graph splitting, piecewise backend, CUDAGraphWrapper |
| vllm/model_executor/ | Model definitions and layers; the attention custom op lives here |
| vllm/distributed/kv_transfer/ | KV connectors for P/D disaggregation (NIXL, mooncake, …) |
| csrc/ | CUDA/C++ kernels: cache ops, MoE, quant; classic PagedAttention in csrc/libtorch_stable/attention/ |
Core mechanisms
1. V1 engine flow: AsyncLLM → EngineCore → tokens out
Two processes. The frontend (AsyncLLM) tokenizes and validates, then ships an EngineCoreRequest over ZMQ to the EngineCore process, which runs a busy loop. The class docstring at vllm/v1/engine/core_client.py:464 spells out the contract:
class MPClient(EngineCoreClient):
    """
    MPClient: base client for multi-proc EngineCore.
    EngineCore runs in a background process busy loop, getting
    new EngineCoreRequests and returning EngineCoreOutputs
    * pushes EngineCoreRequests via input_socket
    * pulls EngineCoreOutputs via output_socket
    """
The heart of the engine is one synchronous step — schedule, launch forward, overlap CPU grammar work with the GPU, sample, then reconcile (vllm/v1/engine/core.py:443):
def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
    """Schedule, execute, and make output. ..."""
    if not self.scheduler.has_requests():
        return {}, False
    scheduler_output = self.scheduler.schedule()
    future = self.model_executor.execute_model(scheduler_output, non_block=True)
    grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
    ...
    model_output = future.result()
    if model_output is None:
        model_output = self.model_executor.sample_tokens(grammar_output)
    ...
    engine_core_outputs = self.scheduler.update_from_output(
        scheduler_output, model_output
    )
Note the overlap trick: execute_model(..., non_block=True) launches the forward pass, then the CPU fills grammar bitmasks while the GPU runs, and sample_tokens(grammar_output) finishes the step. The busy loop wrapping this is vllm/v1/engine/core.py:1223 (run_busy_loop: poll input queue → _process_engine_step). With async_scheduling there is a deeper pipelined variant, step_with_batch_queue at vllm/v1/engine/core.py:484, which schedules step N+1 before step N’s sample completes.
On the way back, a single background asyncio task per frontend (AsyncLLM._run_output_handler, vllm/v1/engine/async_llm.py:637) pulls EngineCoreOutputs batches off the ZMQ socket, runs the output processor in bounded chunks (VLLM_V1_OUTPUT_PROC_CHUNK_SIZE) so it never hogs the event loop, and pushes RequestOutputs into per-request queues. Each generate() call is just a consumer of its own queue (vllm/v1/engine/async_llm.py:576):
finished = False
while not finished:
    # Note: drain queue without await if possible (avoids
    # task switching under load which helps performance).
    out = q.get_nowait() or await q.get()
    assert isinstance(out, RequestOutput)
    finished = out.finished
    if out is not STREAM_FINISHED:
        yield out
Client disconnects surface as asyncio.CancelledError here and trigger an abort RPC back into the EngineCore — the request lifecycle is fully bidirectional, like stream resets in gRPC.
2. Scheduler: continuous batching as a token budget
The scheduler discards the classic prefill/decode dichotomy entirely. The design note at vllm/v1/core/sched/scheduler.py:355:
# NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
# Each request just has the num_computed_tokens and
# num_tokens_with_spec. num_tokens_with_spec =
# len(prompt_token_ids) + len(output_token_ids) + len(spec_token_ids).
# At each step, the scheduler tries to assign tokens to the requests
# so that each request's num_computed_tokens can catch up its
# num_tokens_with_spec. This is general enough to cover
# chunked prefills, prefix caching, speculative decoding,
# and the "jump decoding" optimization in the future.
Every step gets a budget token_budget = self.max_num_scheduled_tokens (vllm/v1/core/sched/scheduler.py:373, i.e. --max-num-batched-tokens, default 2048 per vllm/config/scheduler.py:42). Pass 1 walks RUNNING requests (decodes ask for 1 token, mid-prefill requests ask for the remainder) and clamps each ask (vllm/v1/core/sched/scheduler.py:416):
num_new_tokens = (
    request.num_tokens_with_spec
    + request.num_output_placeholders
    - request.num_computed_tokens
)
if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens:
    num_new_tokens = self.scheduler_config.long_prefill_token_threshold
num_new_tokens = min(num_new_tokens, token_budget)
Pass 2 admits WAITING requests only if nothing was preempted this step (vllm/v1/core/sched/scheduler.py:579), checks the prefix cache for each (get_computed_blocks), and implements chunked prefill as a simple min against the remaining budget (vllm/v1/core/sched/scheduler.py:741): num_new_tokens = request.num_tokens - num_computed_tokens then num_new_tokens = min(num_new_tokens, token_budget). A prompt that doesn’t fit this step simply continues next step — chunking falls out of the bookkeeping for free. Decodes and prefill chunks ride in the same batch; there is no separate prefill queue.
Preemption is the backpressure valve: if allocate_slots returns None (no free KV blocks), the scheduler evicts the lowest-priority / most-recently-arrived running request and retries in a loop (vllm/v1/core/sched/scheduler.py:474-518). Preemption is brutal and simple (vllm/v1/core/sched/scheduler.py:1033):
def _preempt_request(self, request: Request, timestamp: float) -> None:
    ...
    self.kv_cache_manager.free(request)
    self.encoder_cache_manager.free(request)
    ...
    request.status = RequestStatus.PREEMPTED
    request.num_computed_tokens = 0
    ...
    # Put the request back to the waiting queue.
    self.waiting.prepend_request(request)
All KV is dropped and the prompt re-prefills from scratch on readmission (prefix cache hits soften the cost). After the GPU returns, update_from_output (vllm/v1/core/sched/scheduler.py:1388) appends sampled tokens, handles spec-decode rejections, and emits EngineCoreOutputs. The two queues self.waiting/self.running plus this per-step re-planning is continuous batching.
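The token-budget idea in this section fits in a few lines. The sketch below is a simplified model, not vLLM's scheduler: ToyRequest, schedule_step, and apply_step are hypothetical names, there is no KV allocation or preemption, and a decode request does not grow a new token after each step. It only shows how a single `min` against the remaining budget produces chunked prefill and mixed decode/prefill batches.

```python
from dataclasses import dataclass


@dataclass
class ToyRequest:
    """Hypothetical stand-in for a request; not vLLM's Request class."""
    name: str
    num_tokens: int               # prompt + generated tokens known so far
    num_computed_tokens: int = 0  # tokens whose KV already exists


def schedule_step(requests: list[ToyRequest], token_budget: int) -> dict[str, int]:
    """Assign tokens to requests in order until the per-step budget is spent."""
    scheduled: dict[str, int] = {}
    for req in requests:
        if token_budget == 0:
            break
        remaining = req.num_tokens - req.num_computed_tokens
        if remaining <= 0:
            continue
        # Chunked prefill is just this clamp: take what fits, continue next step.
        num_new_tokens = min(remaining, token_budget)
        scheduled[req.name] = num_new_tokens
        token_budget -= num_new_tokens
    return scheduled


def apply_step(requests: list[ToyRequest], scheduled: dict[str, int]) -> None:
    """Advance num_computed_tokens as if the forward pass had run."""
    for req in requests:
        req.num_computed_tokens += scheduled.get(req.name, 0)


reqs = [
    ToyRequest("decode-a", num_tokens=101, num_computed_tokens=100),  # needs 1 token
    ToyRequest("prefill-b", num_tokens=700),                          # long prompt
]
step1 = schedule_step(reqs, token_budget=512)
apply_step(reqs, step1)
step2 = schedule_step(reqs, token_budget=512)
print(step1)  # {'decode-a': 1, 'prefill-b': 511}
print(step2)  # {'prefill-b': 189}
```

In the first step the decode and the first 511 prompt tokens share one batch; the remaining 189 prompt tokens run in the second step without any special "chunking" code path.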
3. KV cache: block pool, prefix caching, eviction
KVCacheManager (vllm/v1/core/kv_cache_manager.py:110) fronts a BlockPool of fixed-size blocks (16 tokens default). Prefix caching is a chained content hash: each request eagerly hashes its tokens into per-block hashes where every hash commits to the whole prefix (vllm/v1/core/kv_cache_utils.py:563):
def hash_block_tokens(
    hash_function: Callable[[Any], bytes],
    parent_block_hash: BlockHash | None,
    curr_block_token_ids: Sequence[int],
    extra_keys: tuple[Any, ...] | None = None,
) -> BlockHash:
    ...
    if not parent_block_hash:
        parent_block_hash = NONE_HASH
    curr_block_token_ids_tuple = tuple(curr_block_token_ids)
    return BlockHash(
        hash_function((parent_block_hash, curr_block_token_ids_tuple, extra_keys))
    )
extra_keys folds in anything that changes KV content beyond token ids: multimodal hashes, LoRA, cache_salt. The hasher is built in vllm/v1/core/kv_cache_utils.py:659. NONE_HASH is seeded from os.urandom at :111 unless pinned; pinning the seed is what gives cross-worker determinism for distributed prefix lookups. On admission, lookup is one longest-prefix-match walk (vllm/v1/core/kv_cache_manager.py:227); note the “must recompute the last token to get logits” cap at :221, which means even a full prefix hit leaves at least the final token to compute, because the step still needs logits to sample from. Cache hits bump refcounts and rescue blocks from the free queue (vllm/v1/core/block_pool.py:402):
def touch(self, blocks: Sequence[KVCacheBlock]) -> None:
    for block in blocks:
        # ref_cnt=0 means this block is in the free list (i.e. eviction
        # candidate), so remove it.
        if block.ref_cnt == 0 and not block.is_null:
            self.free_block_queue.remove(block)
        block.ref_cnt += 1
Eviction is lazy LRU: freed blocks keep their hash and stay in cached_block_hash_to_block while sitting in FreeKVCacheBlockQueue — an intrusive doubly-linked list with O(1) middle-removal, ordered LRU-first then tail-of-chain-first (vllm/v1/core/kv_cache_utils.py:165). Only when a block is reallocated does _maybe_evict_cached_block strip its hash from the cache map (vllm/v1/core/block_pool.py:365) — i.e., the entire idle KV pool doubles as prefix cache. Blocks become reusable as soon as they fill: allocate_slots commits finalized full blocks every step (vllm/v1/core/kv_cache_manager.py:452), so a long generation’s prefix is shareable while it is still generating. Admission control uses watermark headroom (:363-370 — waiting requests can’t starve running ones), and full_sequence_must_fit / reserved_blocks (:278-286) gate chunked-prefill over-admission and async KV-connector loads — connection-pool-style reservation logic.
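Chained block hashing and the longest-prefix walk can be tried out in isolation. The sketch below is an illustration under assumptions: it uses SHA-256 over a hand-rolled byte encoding, a fixed all-zero seed in place of NONE_HASH, a block size of 4, and a plain dict as the cache map. block_hashes and longest_cached_prefix are hypothetical helper names, and refcounts and eviction are left out.

```python
import hashlib

BLOCK_SIZE = 4  # tiny for the demo; the text above cites 16 as the default


def block_hashes(token_ids: list[int], block_size: int = BLOCK_SIZE) -> list[bytes]:
    """Hash each full block together with its parent's hash (chained hashing)."""
    hashes: list[bytes] = []
    parent = b"\x00" * 32  # fixed seed for the sketch; a stand-in for NONE_HASH
    full = len(token_ids) - len(token_ids) % block_size  # ignore the partial tail
    for start in range(0, full, block_size):
        block = token_ids[start:start + block_size]
        payload = parent + b"|" + ",".join(map(str, block)).encode()
        parent = hashlib.sha256(payload).digest()
        hashes.append(parent)
    return hashes


def longest_cached_prefix(hashes: list[bytes], cache: dict[bytes, int]) -> list[int]:
    """Walk block hashes in order and stop at the first miss."""
    hit_block_ids: list[int] = []
    for h in hashes:
        block_id = cache.get(h)
        if block_id is None:
            break
        hit_block_ids.append(block_id)
    return hit_block_ids


# Populate the cache from a first request: two full blocks plus one leftover token.
cache: dict[bytes, int] = {}
first = [1, 2, 3, 4, 5, 6, 7, 8, 9]
for block_id, h in enumerate(block_hashes(first)):
    cache[h] = block_id

same_prefix = [1, 2, 3, 4, 5, 6, 7, 8, 42, 43, 44, 45]
changed_start = [0, 2, 3, 4, 5, 6, 7, 8]
hits_same = longest_cached_prefix(block_hashes(same_prefix), cache)
hits_changed = longest_cached_prefix(block_hashes(changed_start), cache)
print(hits_same)     # [0, 1]
print(hits_changed)  # []
```

The second lookup misses entirely even though its second block holds the same tokens (5, 6, 7, 8) as a cached block: because each hash commits to its parent, a change in the first block invalidates every block after it, which is exactly the property that makes a flat hash map behave like a prefix index.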
4. Structured output: CPU bitmask, GPU mask
Grammar compilation (xgrammar by default) happens off the hot path in a thread pool when a request arrives (StructuredOutputManager.grammar_init, vllm/v1/structured_output/__init__.py:115); until compilation finishes the request is parked in state WAITING_FOR_STRUCTURED_OUTPUT_GRAMMAR. Each step, while the GPU runs the forward pass (see step() above), the scheduler calls grammar_bitmask() (vllm/v1/structured_output/__init__.py:204) which has each request’s FSM matcher write one bit per vocab token into a shared CPU tensor (vllm/v1/structured_output/__init__.py:186):
def _fill_bitmasks(
    self, batch: Iterable[tuple[StructuredOutputGrammar, int, bool]]
) -> None:
    assert self._grammar_bitmask is not None
    for grammar, index, apply_bitmask in batch:
        if apply_bitmask and not grammar.is_terminated():
            grammar.fill_bitmask(self._grammar_bitmask, index)
        else:
            self._grammar_bitmask[index].fill_(self._full_mask)
fill_bitmask is a direct xgrammar call: self.matcher.fill_next_token_bitmask(bitmask, idx) (vllm/v1/structured_output/backend_xgrammar.py:191). Large batches are sharded across a thread pool in groups of 16 (:239-265); with spec decode the FSM is advanced per draft token then rolled back (:278-291). The tensor ships to workers as numpy because “that is much more efficient for serialization” (:297-300).
On the worker, after logits exist, the mask is reordered to batch order, copied H2D non_blocking, and applied in-place by xgrammar’s GPU kernel — invalid tokens become -inf before sampling (vllm/v1/structured_output/utils.py:100):
# Copy async to device as tensor.
grammar_bitmask = torch.from_numpy(sorted_bitmask).to(
    logits.device, non_blocking=True
)
...
xgr.apply_token_bitmask_inplace(logits, grammar_bitmask, indices=index_tensor)
Call site: vllm/v1/worker/gpu_model_runner.py:4410, inside sample_tokens() right before self._sample(...) at :4415. So the per-step loop is exactly the one you know: CPU computes constraint → GPU applies → GPU samples → CPU advances FSM with the sampled token (accept_tokens, vllm/v1/structured_output/backend_xgrammar.py:148, called from the scheduler in update_from_output).
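To make "one bit per vocab token" concrete, here is a pure-Python sketch of packing an allow-list into 32-bit words and applying it to a row of logits. The word layout (token t lives in word t // 32, bit t % 32) is an assumption chosen for illustration, and make_bitmask and apply_bitmask_inplace are hypothetical names; the real work is done by xgrammar on CPU and by its GPU kernel on the device.

```python
import math

BITS_PER_WORD = 32


def make_bitmask(allowed_token_ids: set[int], vocab_size: int) -> list[int]:
    """Pack one allow bit per vocab token into 32-bit words."""
    num_words = (vocab_size + BITS_PER_WORD - 1) // BITS_PER_WORD
    words = [0] * num_words
    for tok in allowed_token_ids:
        words[tok // BITS_PER_WORD] |= 1 << (tok % BITS_PER_WORD)
    return words


def apply_bitmask_inplace(logits: list[float], words: list[int]) -> None:
    """Set the logit of every disallowed token to -inf, in place."""
    for tok in range(len(logits)):
        allowed = (words[tok // BITS_PER_WORD] >> (tok % BITS_PER_WORD)) & 1
        if not allowed:
            logits[tok] = -math.inf


vocab_size = 40
logits = [float(i) for i in range(vocab_size)]  # token 39 would win unmasked
mask = make_bitmask({3, 35}, vocab_size)
apply_bitmask_inplace(logits, mask)
best = max(range(vocab_size), key=lambda t: logits[t])
print(mask)  # [8, 8]
print(best)  # 35
```

A 40-token vocabulary needs only two words here, which hints at why the mask is cheap to serialize and copy: the payload is vocab_size / 32 integers per request rather than a full row of floats.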
5. Sampler: what runs on GPU, what crosses back
Sampler (vllm/v1/sample/sampler.py:20) is an nn.Module over the final-position logits only ([num_reqs, vocab] — logits_indices gathers the last token per request before the LM head, so prefill chunks produce no logits at all). The ordered pipeline is documented at :21-58: logprobs snapshot → fp32 → allowed-tokens/bad-words → min-tokens & logit-bias processors → penalties → temperature → min-p → top-k/top-p → sample. The forward core (vllm/v1/sample/sampler.py:95):
# Use float32 for the logits.
logits = logits.to(torch.float32)
logits = self.apply_logits_processors(
    logits, sampling_metadata, predict_bonus_token
)
# Sample the next token.
sampled, processed_logprobs = self.sample(logits, sampling_metadata)
...
sampled = sampled.long()
Top-k/top-p uses FlashInfer’s sorting-free rejection-sampling kernels when available (vllm/v1/sample/ops/topk_topp_sampler.py:70, flashinfer_sample at :471), else a native PyTorch path. Greedy requests short-circuit to argmax. The only data that must cross GPU→CPU each step is the sampled token ids (plus optional logprobs): _bookkeeping_sync (vllm/v1/worker/gpu_model_runner.py:3557) does valid_sampled_token_ids = self._to_list(sampled_token_ids) (:3616) via a pinned-memory buffer + dedicated copy stream. With async_scheduling even that sync is dodged: token ids stay on-GPU as prev_sampled_token_ids and are spliced into the next step’s input buffer (:3636-3642), letting the CPU schedule step N+1 blind. Sampled ids are not sent back over ZMQ redundantly — the model runner caches them and the scheduler tracks counts (:3649-3653).
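The tail of that pipeline (temperature, then top-k/top-p, then sample, with a greedy short-circuit) can be mimicked on a single row of logits. This is a sort-based reference sketch in pure Python, not the FlashInfer rejection-sampling kernel and not vLLM's code; the function names are hypothetical, and treating temperature 0 as greedy is an assumption of this toy.

```python
import math
import random


def softmax(logits: list[float]) -> list[float]:
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]


def filter_top_k_top_p(logits: list[float], top_k: int, top_p: float) -> list[float]:
    """Keep the top_k highest logits, then the smallest head of them whose
    renormalized probability mass reaches top_p; mask the rest to -inf."""
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    candidates = order[:top_k] if top_k > 0 else order
    probs = softmax([logits[i] for i in candidates])
    kept: set[int] = set()
    cumulative = 0.0
    for idx, p in zip(candidates, probs):
        kept.add(idx)
        cumulative += p
        if cumulative >= top_p:
            break
    return [x if i in kept else -math.inf for i, x in enumerate(logits)]


def sample(logits, temperature, top_k=0, top_p=1.0, rng=None) -> int:
    if temperature == 0.0:
        # Greedy short-circuit: plain argmax, no randomness involved.
        return max(range(len(logits)), key=lambda i: logits[i])
    rng = rng or random.Random()
    scaled = [x / temperature for x in logits]
    probs = softmax(filter_top_k_top_p(scaled, top_k, top_p))
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


logits = [2.0, 1.0, 0.5, -1.0]
rng = random.Random(0)
greedy = sample(logits, temperature=0.0)
top2_draws = {sample(logits, 1.0, top_k=2, rng=rng) for _ in range(200)}
nucleus_draws = {sample(logits, 1.0, top_p=0.5, rng=rng) for _ in range(200)}
print(greedy, top2_draws, nucleus_draws)
```

With top_k=2 only tokens 0 and 1 can ever be drawn, and with top_p=0.5 only token 0 survives because it alone already holds more than half the probability mass.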
6. CUDA graphs + piecewise torch.compile
Decode steps are tiny (one token per sequence), so kernel-launch overhead dominates; vLLM amortizes it by replaying pre-captured CUDA graphs. The wrinkle: attention with a growing paged KV cache is awkward to capture. The solution is piecewise capture — torch.compile traces the model, then the FX graph is split at attention ops (split_graph, vllm/compilation/backends.py:548), with the splitting-op list defined at vllm/config/compilation.py:745:
# Attention ops; used for piecewise cudagraphs
# Use PyTorch operator format: "namespace::name"
_attention_ops: ClassVar[list[str]] = [
    "vllm::unified_attention_with_output",
    "vllm::unified_mla_attention_with_output",
    "vllm::mamba_mixer2",
    ...
]
Everything between attentions (MLP, norms, projections — shape-static given a padded token count) gets compiled and captured as cudagraphs by PiecewiseBackend (vllm/compilation/piecewise_backend.py:86); attention runs eager in the gaps. At runtime, CudagraphDispatcher.dispatch (vllm/v1/cudagraph_dispatcher.py:239) pads the batch up to the nearest captured size and selects FULL (pure uniform decode, attention captured too) vs PIECEWISE vs NONE; modes enumerated at vllm/config/compilation.py:53 (FULL_AND_PIECEWISE = decode uses full graphs, mixed prefill batches use piecewise). The wrapper itself is cache-keyed on BatchDescriptor (vllm/compilation/cuda_graph.py:233):
entry = self.concrete_cudagraph_entries[batch_descriptor]
if entry.cudagraph is None:
    ...
    cudagraph = torch.cuda.CUDAGraph()
    ...
    with torch.cuda.graph(
        cudagraph,
        pool=self.graph_pool,
        stream=current_stream(),
    ):
        # `output` is managed by pytorch's cudagraph pool
        output = self.runnable(*args, **kwargs)
…and replay is just entry.cudagraph.replay() (vllm/compilation/cuda_graph.py:360). Replay requires identical input addresses, so the runner owns persistent input buffers and copies each step’s data into them; debug mode asserts data_ptr() equality (:346-355).
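The pad-then-look-up behaviour described above can be modelled without a GPU. The sketch below is a toy under assumptions: CAPTURED_SIZES is a made-up capture list, ToyGraphCache and pad_to_captured_size are hypothetical names, the cache key is just the padded token count rather than a full batch descriptor, and capture happens lazily on first use purely for brevity.

```python
import bisect

CAPTURED_SIZES = [1, 2, 4, 8, 16, 32]  # hypothetical capture list, kept sorted


def pad_to_captured_size(num_tokens: int, captured_sizes=CAPTURED_SIZES):
    """Return the smallest captured size >= num_tokens, or None to run eager."""
    i = bisect.bisect_left(captured_sizes, num_tokens)
    return captured_sizes[i] if i < len(captured_sizes) else None


class ToyGraphCache:
    """Capture once per padded size, replay on every later hit."""

    def __init__(self) -> None:
        self.entries: dict[int, str] = {}
        self.captures = 0
        self.replays = 0

    def run(self, num_tokens: int) -> str:
        padded = pad_to_captured_size(num_tokens)
        if padded is None:
            return "eager"  # batch larger than anything captured
        if padded not in self.entries:
            self.entries[padded] = f"graph[{padded}]"  # stand-in for a CUDA graph
            self.captures += 1
            return "capture"
        self.replays += 1
        return "replay"


cache = ToyGraphCache()
modes = [cache.run(n) for n in (3, 4, 5, 40)]
print(modes)  # ['capture', 'replay', 'capture', 'eager']
```

Batches of 3 and 4 tokens share one graph because both pad to 4; the cost of padding is a few wasted token slots, traded against not needing a graph per exact batch size.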
7. Attention kernels and dispatch
Models never call a backend directly — every attention layer calls a registered PyTorch custom op which torch.compile treats as opaque (this is what makes graph splitting possible), vllm/model_executor/layers/attention/attention.py:734:
def unified_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: LayerNameType,
    ...
) -> None:
    ...
    layer_name = _resolve_layer_name(layer_name)
    attn_metadata, self, kv_cache, _ = get_attention_context(layer_name)
    self.impl.forward(
        self, query, key, value, kv_cache, attn_metadata, output=output, ...)
(registered via direct_register_custom_op at :777). The impl is chosen at startup by get_attn_backend (vllm/v1/attention/selector.py:54) → platform logic → AttentionBackendEnum (vllm/v1/attention/backends/registry.py:34), which maps names to classes: FLASH_ATTN → vllm/v1/attention/backends/flash_attn.py, FLASHINFER → flashinfer.py, TRITON_ATTN, plus a whole mla/ family for DeepSeek-style latent attention. Override with --attention-backend FLASHINFER (vllm/engine/arg_utils.py:907).
The FlashAttention backend is the default on NVIDIA. Its forward (vllm/v1/attention/backends/flash_attn.py:698) first appends this step’s K/V into the paged cache (reshape_and_cache_flash, :919), then runs flash_attn_varlen_func (:839) over ragged sequences — block_table arg makes it paged-aware, so “PagedAttention” today is a feature of the FA kernel rather than a separate kernel. The original PagedAttention v1/v2 CUDA kernels still exist at csrc/libtorch_stable/attention/paged_attention_v1.cu / paged_attention_v2.cu but are off the main path. Worth reading for your performance instincts, vllm/v1/attention/backends/flash_attn.py:740:
# IMPORTANT!
# NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
# eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
# in this method. For example, `view` and `slice` (or `[:n]`) operations
# are surprisingly slow even in the case they do not invoke any GPU ops.
8. Streaming egress: OpenAI server, detokenization, stop strings
vllm serve builds a FastAPI app (vllm/entrypoints/openai/api_server.py); /v1/chat/completions lives at vllm/entrypoints/openai/chat_completion/api_router.py:40 and returns StreamingResponse(content=generator, media_type="text/event-stream") (:74). The SSE generator (chat_completion_stream_generator, vllm/entrypoints/openai/chat_completion/serving.py:398) iterates AsyncLLM.generate() and emits chat.completion.chunk deltas, running tool-call/reasoning parsers incrementally per choice.
Underneath, OutputProcessor.process_outputs (vllm/v1/engine/output_processor.py:576) is deliberately “the only function that should loop over EngineCoreOutputs” per batch. Detokenization is incremental and stateful per request: the fast path wraps the HF tokenizers Rust DecodeStream (FastIncrementalDetokenizer, vllm/v1/engine/detokenizer.py:167), primed with the prompt ids so byte-level merges across the prompt/output boundary decode correctly. Stop-string handling is the subtle part — text, not tokens: a stop string can span token boundaries, so the detokenizer withholds max(len(stop)) - 1 chars from the stream (vllm/v1/engine/detokenizer.py:84):
# Number of chars to hold back when stop strings are to be excluded
# from streamed output.
if self.stop and not self.include_stop_str_in_output:
    self.stop_buffer_length = max(len(s) for s in self.stop) - 1
else:
    self.stop_buffer_length = 0
check_stop_strings (vllm/v1/engine/detokenizer.py:309) scans only the newly decoded chars (plus the overlap window) and truncates output_text at the match. Because stop strings are detected in the frontend while the EngineCore is already generating the next tokens, the frontend fires an explicit abort back upstream (vllm/v1/engine/output_processor.py:678-681: “If req not finished in EngineCore, but Detokenizer detected stop string, abort needed in EngineCore”) — a classic distributed cancellation race, handled exactly like client-disconnect aborts. Prometheus is mounted as an ASGI sub-app at /metrics (vllm/entrypoints/serve/instrumentator/metrics.py:41); the series definitions are all in vllm/v1/metrics/loggers.py (vllm:num_requests_running :452, vllm:num_requests_waiting :462, vllm:kv_cache_usage_perc :520, vllm:prefix_cache_queries/hits :543/:554, vllm:num_preemptions :620).
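The hold-back rule is easier to trust after stepping through it. The following sketch is a simplified model, not vLLM's detokenizer: ToyStopStreamer is a hypothetical class that receives already-decoded text pieces, always excludes the stop string from the output, and ignores the abort that the real frontend sends upstream.

```python
class ToyStopStreamer:
    """Streams text deltas while withholding enough chars to hide a stop string."""

    def __init__(self, stop: list[str]) -> None:
        self.stop = stop
        # A stop string of length n can have at most n - 1 chars already decoded.
        self.hold_back = max(len(s) for s in stop) - 1 if stop else 0
        self.text = ""       # everything decoded so far
        self.emitted = 0     # number of chars already streamed out
        self.finished = False

    def push(self, new_text: str) -> str:
        """Append newly decoded text and return the delta that is safe to emit."""
        if self.finished:
            return ""
        # Scan only the new chars plus the overlap window before them.
        search_from = max(0, len(self.text) - self.hold_back)
        self.text += new_text
        positions = [self.text.find(s, search_from) for s in self.stop]
        hits = [p for p in positions if p != -1]
        if hits:
            self.text = self.text[:min(hits)]  # truncate at the earliest match
            self.finished = True
            limit = len(self.text)
        else:
            limit = len(self.text) - self.hold_back
        delta = self.text[self.emitted:max(limit, self.emitted)]
        self.emitted += len(delta)
        return delta


streamer = ToyStopStreamer(stop=["</s>"])
deltas = [streamer.push(piece) for piece in ("Hello", " wor", "ld<", "/s> extra")]
print(deltas)               # ['He', 'llo ', 'wor', 'ld']
print("".join(deltas))      # Hello world
```

The stop string arrives split across two pieces ("<" then "/s>"), yet no part of it is ever streamed, because the last three characters are always withheld until a later piece proves they are not the start of a match.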
Suggested reading path
1. vllm/v1/engine/core.py: read EngineCore.__init__, step() (:443), and EngineCoreProc.run_busy_loop (:1223) first; everything else hangs off this loop.
2. vllm/v1/core/sched/scheduler.py: schedule() (:353) top to bottom, then update_from_output (:1388). The single most important file in the repo.
3. vllm/v1/core/kv_cache_manager.py: get_computed_blocks (:202) and allocate_slots (:244), including the block-layout ASCII diagram at :290.
4. vllm/v1/core/block_pool.py + vllm/v1/core/kv_cache_utils.py: hash table, free-queue LRU, hash_block_tokens (kv_cache_utils.py:563).
5. vllm/v1/worker/gpu_model_runner.py: skim execute_model (:4000), sample_tokens (:4379), _bookkeeping_sync (:3557); this is where scheduler abstractions become tensors.
6. vllm/v1/engine/async_llm.py: generate() (:524) and _run_output_handler (:637) for the frontend half; glance at vllm/v1/engine/core_client.py:464 for the ZMQ seam.
7. vllm/v1/engine/output_processor.py + vllm/v1/engine/detokenizer.py: egress, stop strings, the abort race.
8. vllm/v1/structured_output/__init__.py + vllm/v1/structured_output/utils.py: the CPU/GPU bitmask handshake.
9. vllm/compilation/cuda_graph.py + vllm/v1/cudagraph_dispatcher.py: capture/replay and dispatch keys.
10. vllm/v1/attention/backends/flash_attn.py: one full backend end-to-end (metadata builder, cache write, varlen kernel call).
Bonus: vllm/v1/core/sched/output.py (SchedulerOutput — the exact wire contract between scheduler and workers) and vllm/v1/request.py (per-request state machine).
Connections to your study set
- nano-vllm: a ~1.2k-line reimplementation of exactly the files above. Its scheduler ≈ scheduler.py minus chunked prefill/connectors, its block manager ≈ kv_cache_manager.py + block_pool.py with the same parent-chained block hashing. Read nano-vllm first or alongside; every concept there has a 10x-more-edge-cases twin here.
- sglang: the main competing engine. Same continuous batching and paged KV, but prefix caching is a radix tree over token sequences (RadixAttention) instead of vLLM’s flat hash-of-full-blocks map; compare with block_pool.py’s cached_block_hash_to_block. Its Rust sgl-router does cache-aware routing by approximating each replica’s radix tree; vLLM’s equivalent signal is the --kv-events stream (BlockStored/BlockRemoved published from vllm/v1/core/block_pool.py:392) plus prefix-cache metrics.
- dynamo: NVIDIA’s Rust distributed layer above engines: KV-aware routing, P/D disaggregation orchestration. It plugs into vLLM through the KV-connector API you saw threaded through the scheduler (self.connector.get_num_new_matched_tokens at vllm/v1/core/sched/scheduler.py:674, WAITING_FOR_REMOTE_KVS state at :866) with NIXL as the transfer fabric (vllm/distributed/kv_transfer/kv_connector/v1/).
- llm-d: K8s-native distributed inference built around vLLM replicas; its scheduler/EPP consumes the same /metrics series defined in vllm/v1/metrics/loggers.py for load- and cache-aware placement.
- gateway-api-inference-extension: your Envoy world. An ext-proc endpoint picker that scrapes vllm:num_requests_waiting, vllm:kv_cache_usage_perc, and LoRA-adapter metrics from each pod to pick a backend. The chat router even accepts an endpoint-load-metrics-format header (vllm/entrypoints/openai/chat_completion/api_router.py:54) so the gateway can get load data inline with responses.
- xgrammar: the default structured-output backend; vLLM’s wrapper is vllm/v1/structured_output/backend_xgrammar.py (compile → fill_next_token_bitmask on CPU → apply_token_bitmask_inplace on GPU). vLLM’s Triton port of xgrammar’s apply kernel lives in the experimental runner at vllm/v1/worker/gpu/structured_outputs.py:86.
- flashinfer: optional kernel library used in two places: as a full attention backend (vllm/v1/attention/backends/flashinfer.py, good paged/cascade decode kernels) and for sorting-free top-k/top-p sampling (vllm/v1/sample/ops/topk_topp_sampler.py:471).
Tinkering on one RTX 5080 (16GB)
Blackwell (SM120) needs a recent CUDA 12.8+ wheel — uv pip install vllm --torch-backend=auto handles it. Good fits in 16GB: Qwen/Qwen3-4B-Instruct-2507 (BF16, roomy KV), Qwen/Qwen3-8B-FP8, or RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w4a16. Keep --max-model-len 8192 to start; watch the startup log line that reports how many GPU KV blocks fit.
vllm serve Qwen/Qwen3-4B-Instruct-2507 \
--max-model-len 8192 --gpu-memory-utilization 0.90
1. Watch the engine think. VLLM_LOGGING_LEVEL=DEBUG vllm serve ... shows EngineCore waiting for work, per-shape Capturing a cudagraph (PIECEWISE, ...) lines at startup, and scheduler activity; map each line to the code above.
2. Prefix caching A/B. Send 50 requests sharing a 2-4k-token system prompt, then rerun with --no-enable-prefix-caching. Compare TTFT and vllm:prefix_cache_queries vs vllm:prefix_cache_hits on /metrics. Use POST /reset_prefix_cache between trials for clean runs; add "cache_salt" to a request body to deliberately fork the hash chain (kv_cache_utils.py extra_keys).
3. Chunked prefill / budget pressure. Sweep --max-num-batched-tokens 512 2048 8192 while one client streams a decode-heavy chat and another fires 6k-token prompts. At 512, watch a single prefill get sliced across many steps (decode ITL stays smooth); at 8192, ITL spikes whenever a prefill lands. Then cap with --long-prefill-token-threshold 256 and watch the spike flatten: this is the min(num_new_tokens, token_budget) line doing its job.
4. max-num-seqs sweep + forced preemption. vllm bench serve against --max-num-seqs 16/64/128/256; find where throughput saturates and ITL inflates (KV-bound, not compute-bound). Then set --gpu-memory-utilization 0.5 --max-model-len 16384 and long generations to starve the block pool: vllm:num_preemptions ticks up and you’ll see _preempt_request recompute-from-zero behavior in tail latencies.
5. Prometheus to dashboard. examples/observability/prometheus_grafana/ ships a prometheus.yaml scrape config and grafana.json dashboard (TTFT/ITL histograms, KV usage, queue depth). Wire it up and reproduce the queue-depth → kv-usage → preemption causality chain during experiment 4. This is the exact signal set the Envoy inference EPP routes on.
6. Real traces. wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json then:
vllm bench serve --model Qwen/Qwen3-4B-Instruct-2507 \
--dataset-name sharegpt --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \
--request-rate 8 --num-prompts 500
(implementation: vllm/benchmarks/serve.py, dataset loaders in vllm/benchmarks/datasets/datasets.py:1340). Sweep --request-rate to find the knee; compare goodput under Poisson arrivals vs --burstiness.
7. CUDA graph ablation. Run the same decode benchmark three ways: default, --compilation-config '{"cudagraph_mode": "NONE"}', and --enforce-eager (no compile at all). Single-stream decode ITL makes launch overhead visible; on a small model expect a 20-40% gap eager vs graphs.
8. Structured output cost. Fire the same prompts with and without response_format={"type": "json_schema", ...}. Per-step overhead is the CPU fill_next_token_bitmask + H2D mask copy from section 4; it shows up as ITL delta that grows with batch size, and you can profile the grammar_bitmask call inside EngineCore.step with py-spy dump on the EngineCore process — a nice demo that the frontend and engine really are separate PIDs.
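For the experiments that compare counters on /metrics, a small parser for the Prometheus text format saves squinting at raw output. The sketch below is a minimal helper under assumptions: parse_prometheus and prefix_cache_hit_rate are hypothetical names, the sample text and its values are made up, and the `_total` suffix on the counter names is how Prometheus clients usually expose counters rather than something confirmed against a running server. It also assumes no timestamps after the sample values.

```python
def parse_prometheus(text: str) -> dict[str, float]:
    """Sum sample values per metric name, ignoring labels and comment lines."""
    totals: dict[str, float] = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        name_and_labels, _, value = line.rpartition(" ")
        name = name_and_labels.split("{", 1)[0]
        totals[name] = totals.get(name, 0.0) + float(value)
    return totals


def prefix_cache_hit_rate(totals: dict[str, float]) -> float:
    """Hits divided by queries; tolerate names with or without '_total'."""
    def find(prefix: str) -> float:
        return sum(v for k, v in totals.items() if k.startswith(prefix))

    queries = find("vllm:prefix_cache_queries")
    hits = find("vllm:prefix_cache_hits")
    return hits / queries if queries else 0.0


# Made-up sample in Prometheus text exposition format.
SAMPLE = """\
# HELP vllm:prefix_cache_queries_total Prefix cache queries.
# TYPE vllm:prefix_cache_queries_total counter
vllm:prefix_cache_queries_total{model_name="demo"} 1000.0
vllm:prefix_cache_hits_total{model_name="demo"} 750.0
vllm:num_requests_waiting{model_name="demo"} 3.0
"""

totals = parse_prometheus(SAMPLE)
hit_rate = prefix_cache_hit_rate(totals)
print(hit_rate)  # 0.75
```

To use it against a live server, fetch the body of /metrics with any HTTP client and pass the text to parse_prometheus; taking the difference between two scrapes gives the hit rate for just the trial in between.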
nano-vllm
A from-scratch reimplementation of vLLM’s core inference loop in ~1,450 lines of Python (version 0.2.0, commit bb823b3, “chunked-prefill-refactor”). One model family (Qwen3 dense), one platform (CUDA), offline batch only — but the four mechanisms that make modern LLM serving fast are all here, each in a single readable file: continuous batching, paged KV cache, hash-based prefix caching, and CUDA-graph decode. It even includes tensor parallelism over a homemade shared-memory RPC, and chunked prefill.
Why read this before real vLLM: vLLM V1 implements the same architecture — same queues, same block hashing, same prefill/decode batch preparation — but spread across hundreds of files with feature flags, backends, and abstraction layers. nano-vllm is the architecture with the adjectives removed. After this, opening vllm/v1/core/sched/scheduler.py feels like reading a familiar program with more options, not an alien codebase. The author benchmarks it slightly faster than vLLM on small models (README.md: 1434 vs 1362 tok/s, Qwen3-0.6B on an 8GB laptop GPU), so it is not a toy that cheats.
Why you care
- The scheduler is an admission controller. Scheduler.schedule() is 49 lines that decide, every iteration, which requests get GPU time against a token budget and a block pool. That is structurally the same problem as adaptive concurrency / load shedding in Envoy, except the contended resource is KV-cache blocks instead of connections. Preemption here is literally load shedding with retry-from-front-of-queue.
- Prefix caching is content-addressable storage with refcounts. Once you’ve read BlockManager.hash_blocks, the “KV-cache-aware routing” in llm-d, dynamo, and gateway-api-inference-extension stops being marketing: those routers are trying to predict, from outside the engine, what hash_to_block_id contains inside each replica.
- It defines the signal set. Queue depths, free blocks, prefill-vs-decode token throughput, preemption events: everything a serving-layer LB wants exported from an engine is a one-line print away in this codebase. You can watch the whole control plane think in a terminal.
Complete file map
| File | Lines | Purpose |
|---|---|---|
| nanovllm/__init__.py | 2 | Exports LLM and SamplingParams |
| nanovllm/llm.py | 5 | class LLM(LLMEngine): pass, the vLLM-compatible facade |
| nanovllm/config.py | 25 | Config dataclass: budgets, block size, TP size; pulls HF config; asserts model path is a local dir |
| nanovllm/sampling_params.py | 11 | temperature / max_tokens / ignore_eos: that’s the entire sampling API (greedy is asserted away) |
| nanovllm/engine/llm_engine.py | 90 | Engine: spawns TP worker processes, owns tokenizer + scheduler, runs the step() loop, tracks tok/s |
| nanovllm/engine/scheduler.py | 92 | Continuous batching: waiting/running deques, chunked prefill, preemption, finish detection |
| nanovllm/engine/block_manager.py | 120 | Paged KV allocator + prefix cache via chained xxhash64 block hashing, refcounts, lazy eviction |
| nanovllm/engine/sequence.py | 83 | Sequence: per-request state machine (WAITING/RUNNING/FINISHED), block table, token counters |
| nanovllm/engine/model_runner.py | 257 | Per-GPU worker: KV memory sizing, batch tensor prep, CUDA graph capture/replay, sampling, shm RPC for TP |
| nanovllm/layers/attention.py | 75 | Triton kernel writing K/V into paged cache + flash-attn varlen (prefill) / with-kvcache (decode) |
| nanovllm/layers/sampler.py | 12 | Temperature sampling via the Gumbel/exponential trick, torch.compiled |
| nanovllm/layers/linear.py | 156 | TP linear layers (Column/Row/QKV/Merged) with per-param weight_loader (vLLM’s loading idiom) |
| nanovllm/layers/embed_head.py | 66 | Vocab-parallel embedding; LM head that slices only each sequence’s last position before projecting |
| nanovllm/layers/rotary_embedding.py | 59 | Precomputed cos/sin RoPE cache |
| nanovllm/layers/layernorm.py | 50 | RMSNorm with fused residual-add variant |
| nanovllm/layers/activation.py | 11 | SiLU-and-mul (SwiGLU gate) |
| nanovllm/models/qwen3.py | 216 | Qwen3 decoder assembled from the layers above; packed_modules_mapping for fused-weight loading |
| nanovllm/utils/context.py | 27 | Process-global Context dataclass carrying attention metadata into the forward pass |
| nanovllm/utils/loader.py | 28 | Safetensors loader dispatching to per-parameter weight_loaders |
| example.py | 33 | Chat-template generation demo (Qwen3-0.6B) |
| bench.py | 32 | 256-sequence random-length throughput benchmark, vLLM-comparable |
Core mechanisms
The engine step loop
LLMEngine.generate() (nanovllm/engine/llm_engine.py:60-90) enqueues all prompts, then spins step() until both queues drain. Each step is one scheduler decision plus one GPU forward:
def step(self):
    seqs, is_prefill = self.scheduler.schedule()
    num_tokens = sum(seq.num_scheduled_tokens for seq in seqs) if is_prefill else -len(seqs)
    token_ids = self.model_runner.call("run", seqs, is_prefill)
    self.scheduler.postprocess(seqs, token_ids, is_prefill)
    outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished]
    return outputs, num_tokens
(nanovllm/engine/llm_engine.py:49-55.) The key design choice: every batch is either pure prefill or pure decode — schedule() returns a flag, and the model runner builds entirely different tensors for each case. Prefill has strict priority: decode only runs when the waiting queue is empty or blocked on memory. (vLLM V1 abandoned this dichotomy — see the comparison table.) The num_tokens sign trick (positive = prefill tokens, negative = decode batch size) only feeds the tqdm throughput display.
Block manager: paged KV + prefix caching
BlockManager (nanovllm/engine/block_manager.py) owns num_kvcache_blocks logical blocks of 256 tokens each (vLLM defaults to 16). A Block is {block_id, ref_count, hash, token_ids}. Prefix caching is a chained content hash: each block’s hash commits to the entire prefix before it, so a single dict hit on block i certifies the whole prefix through block i, and matching an arbitrarily long prefix costs one lookup per block:
@classmethod
def compute_hash(cls, token_ids: list[int], prefix: int = -1):
    h = xxhash.xxh64()
    if prefix != -1:
        h.update(prefix.to_bytes(8, "little"))
    h.update(np.array(token_ids).tobytes())
    return h.intdigest()
(nanovllm/engine/block_manager.py:35-41.) On admission, can_allocate (nanovllm/engine/block_manager.py:58-73) walks the prompt’s full blocks, following the hash chain until the first miss:
h = -1
num_cached_blocks = 0
num_new_blocks = seq.num_blocks
for i in range(seq.num_blocks - 1):
    token_ids = seq.block(i)
    h = self.compute_hash(token_ids, h)
    block_id = self.hash_to_block_id.get(h, -1)
    if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
        break
    num_cached_blocks += 1
    if block_id in self.used_block_ids:
        num_new_blocks -= 1
if len(self.free_block_ids) < num_new_blocks:
    return -1
return num_cached_blocks
Three details worth noticing: the stored token_ids comparison guards against hash collisions; the loop deliberately stops at num_blocks - 1 so the last (possibly partial) block is always recomputed — which also guarantees at least one query token exists to produce logits from; and -1 means “not enough memory, stall admission.”
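The chain-and-walk logic above can be reproduced without the engine. The following is a minimal sketch, not nano-vllm's code: it swaps xxhash for the standard library's hashlib.blake2b, uses a toy block size of 4, and the helper names (chain_hash, register, count_cached_blocks) are hypothetical.

```python
import hashlib

BLOCK_SIZE = 4  # toy value for illustration; nano-vllm uses 256


def chain_hash(token_ids, prefix=-1):
    """Hash one full block, folding in the previous block's hash."""
    h = hashlib.blake2b(digest_size=8)  # stand-in for xxhash.xxh64 (assumption)
    if prefix != -1:
        h.update(prefix.to_bytes(8, "little"))
    for t in token_ids:
        h.update(t.to_bytes(4, "little"))
    return int.from_bytes(h.digest(), "little")


def register(cache, tokens):
    """Record every full block of `tokens` under its chained hash."""
    h = -1
    for i in range(len(tokens) // BLOCK_SIZE):
        block = tokens[i * BLOCK_SIZE:(i + 1) * BLOCK_SIZE]
        h = chain_hash(block, h)
        cache[h] = block


def count_cached_blocks(cache, tokens):
    """Walk the hash chain until the first miss."""
    h, hits = -1, 0
    num_blocks = -(-len(tokens) // BLOCK_SIZE)  # ceiling division
    for i in range(num_blocks - 1):  # the last block is never matched
        block = tokens[i * BLOCK_SIZE:(i + 1) * BLOCK_SIZE]
        h = chain_hash(block, h)
        if cache.get(h) != block:  # miss, or a collision caught by comparing tokens
            break
        hits += 1
    return hits
```

Because each hash folds in its predecessor, the same four tokens at a different position in the prompt produce a different key, so a hit on block i implies every earlier block matched too.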
allocate (nanovllm/engine/block_manager.py:75-92) then bumps refcounts on hits — including resurrecting blocks that are sitting on the free list but still hold valid data. That’s the eviction policy: deallocation just appends to the free deque without clearing the hash, so a “freed” block remains a cache hit until _allocate_block (nanovllm/engine/block_manager.py:43-51) actually reuses it and unregisters its hash:
for i in range(num_cached_blocks):
    token_ids = seq.block(i)
    h = self.compute_hash(token_ids, h)
    block_id = self.hash_to_block_id[h]
    block = self.blocks[block_id]
    if block_id in self.used_block_ids:
        block.ref_count += 1
    else:
        block.ref_count = 1
        self.free_block_ids.remove(block_id)
        self.used_block_ids.add(block_id)
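The lazy-eviction policy described above (a freed block stays a cache hit until it is physically reused) can be seen in isolation. This is a simplified sketch, not the repository's BlockManager: the TinyPool class and its method names are hypothetical, and refcounts are left out to keep the eviction moment visible.

```python
from collections import deque


class TinyPool:
    """Toy block pool: freed blocks keep their hash until reused."""

    def __init__(self, num_blocks):
        self.free = deque(range(num_blocks))
        self.used = set()
        self.hash_of = {}        # block_id -> hash currently stored in it
        self.hash_to_block = {}  # hash -> block_id

    def take_fresh(self, h):
        """Reuse the oldest free block; this is the moment of eviction."""
        block_id = self.free.popleft()
        old = self.hash_of.pop(block_id, None)
        if old is not None and self.hash_to_block.get(old) == block_id:
            del self.hash_to_block[old]  # old contents are now gone
        self.used.add(block_id)
        self.hash_of[block_id] = h
        self.hash_to_block[h] = block_id
        return block_id

    def release(self, block_id):
        """Free a block without clearing its hash."""
        self.used.remove(block_id)
        self.free.append(block_id)

    def lookup(self, h):
        """Return a cached block, resurrecting it from the free list."""
        block_id = self.hash_to_block.get(h)
        if block_id is None:
            return None
        if block_id not in self.used:
            self.free.remove(block_id)
            self.used.add(block_id)
        return block_id
```

Releasing appends to the back of the free deque and reuse pops from the front, so in this sketch the longest-freed block is evicted first.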
During decode, growth is incremental — one new block exactly when the latest token crossed a block boundary, with a bool-as-int trick:
def can_append(self, seq: Sequence) -> bool:
    return len(self.free_block_ids) >= (len(seq) % self.block_size == 1)
def may_append(self, seq: Sequence):
    if len(seq) % self.block_size == 1:
        seq.block_table.append(self._allocate_block())
(nanovllm/engine/block_manager.py:103-108.) Finally, after every model step, hash_blocks registers any block that just became full — note this runs during decode too, so generated tokens become prefix-cacheable (multi-turn reuse for free):
def hash_blocks(self, seq: Sequence):
    start = seq.num_cached_tokens // self.block_size
    end = (seq.num_cached_tokens + seq.num_scheduled_tokens) // self.block_size
    if start == end: return
    h = self.blocks[seq.block_table[start - 1]].hash if start > 0 else -1
    for i in range(start, end):
        block = self.blocks[seq.block_table[i]]
        token_ids = seq.block(i)
        h = self.compute_hash(token_ids, h)
        block.update(h, token_ids)
        self.hash_to_block_id[h] = block.block_id
(nanovllm/engine/block_manager.py:110-120.)
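The len(seq) % block_size == 1 test is easy to check by hand. The sketch below is illustrative arithmetic only, with hypothetical helper names; it grows a sequence one token at a time and confirms that allocating a block exactly on that condition always yields the ceiling-division block count.

```python
def needs_new_block(seq_len, block_size=256):
    """True when the newest token is the first token of a fresh block."""
    return seq_len % block_size == 1


def blocks_for(seq_len, block_size=256):
    """Blocks needed to hold seq_len tokens (ceiling division)."""
    return -(-seq_len // block_size)


def simulate_decode(prompt_len, num_new_tokens, block_size=256):
    """Grow a block table one token at a time and count its blocks."""
    num_blocks = blocks_for(prompt_len, block_size)  # allocated at admission
    for seq_len in range(prompt_len + 1, prompt_len + num_new_tokens + 1):
        # bool-as-int: True counts as one required block, False as zero
        num_blocks += int(needs_new_block(seq_len, block_size))
    return num_blocks
```

The same bool-as-int coercion is what lets can_append compare a free-block count against a boolean: 0 >= False holds, 0 >= True does not.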
Continuous batching, chunked prefill, preemption
Scheduler (nanovllm/engine/scheduler.py) keeps two deques, waiting and running. The prefill pass (nanovllm/engine/scheduler.py:30-52) admits from waiting FCFS while under max_num_seqs (512) and max_num_batched_tokens (16384), discounting prefix-cache hits from the work and chunking long prompts:
if not seq.block_table:
    num_cached_blocks = self.block_manager.can_allocate(seq)
    if num_cached_blocks == -1:
        break
    num_tokens = seq.num_tokens - num_cached_blocks * self.block_size
else:
    num_tokens = seq.num_tokens - seq.num_cached_tokens
if remaining < num_tokens and scheduled_seqs:    # only allow chunked prefill for the first seq
    break
if not seq.block_table:
    self.block_manager.allocate(seq, num_cached_blocks)
seq.num_scheduled_tokens = min(num_tokens, remaining)
(nanovllm/engine/scheduler.py:35-46.) A sequence whose prompt doesn’t fit in the remaining budget gets a partial chunk (only as the first seq of the batch) and stays in waiting with its block_table already fully allocated; it’s promoted to running only once the whole prompt has been scheduled (nanovllm/engine/scheduler.py:48-51). Blocks for the entire prompt are allocated up front at admission — simpler than vLLM, which allocates as chunks proceed.
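To see the budget and chunking rules without a GPU, here is a minimal simulation of the prefill pass as described above. It is a sketch under assumptions, not the repository's scheduler: block allocation is omitted entirely, and Req, schedule_prefill, and finish_step are hypothetical names.

```python
from collections import deque
from dataclasses import dataclass


@dataclass
class Req:
    name: str
    num_tokens: int              # prompt length
    num_cached_tokens: int = 0   # tokens whose KV already exists
    num_scheduled_tokens: int = 0


def schedule_prefill(waiting, running, max_num_seqs, max_num_batched_tokens):
    """One prefill pass: FCFS admission under a sequence cap and token budget."""
    scheduled = []
    remaining = max_num_batched_tokens
    while waiting and len(scheduled) < max_num_seqs and remaining > 0:
        req = waiting[0]
        todo = req.num_tokens - req.num_cached_tokens
        # Only the first request of a batch may be split into a chunk.
        if remaining < todo and scheduled:
            break
        req.num_scheduled_tokens = min(todo, remaining)
        remaining -= req.num_scheduled_tokens
        scheduled.append(req)
        if req.num_scheduled_tokens == todo:
            waiting.popleft()    # whole prompt scheduled: promote
            running.append(req)
        else:
            break                # a partial chunk used up the budget
    return scheduled


def finish_step(scheduled):
    """Mimic postprocess: scheduled tokens become cached tokens."""
    for req in scheduled:
        req.num_cached_tokens += req.num_scheduled_tokens
        req.num_scheduled_tokens = 0
```

With a 1024-token budget, a 2500-token prompt followed by a 300-token prompt takes three steps: two full chunks of 1024, then a final 452-token chunk that shares its batch with the whole second prompt.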
If nothing was admitted, the decode pass runs — and this is where requests leave the batch under memory pressure:
while self.running and len(scheduled_seqs) < self.max_num_seqs:
    seq = self.running.popleft()
    while not self.block_manager.can_append(seq):
        if self.running:
            self.preempt(self.running.pop())
        else:
            self.preempt(seq)
            break
    else:
        seq.num_scheduled_tokens = 1
        seq.is_prefill = False
        self.block_manager.may_append(seq)
        scheduled_seqs.append(seq)
assert scheduled_seqs
self.running.extendleft(reversed(scheduled_seqs))
(nanovllm/engine/scheduler.py:58-72.) Victim selection is self.running.pop() — the youngest running sequence is sacrificed for the oldest (LIFO shedding, FCFS service). preempt (nanovllm/engine/scheduler.py:75-79) frees all the victim’s blocks and pushes it to the front of waiting: preemption is full recompute, no CPU swap — softened in practice because its full blocks are still hash-registered, so re-prefill is mostly cache hits if they haven’t been reused. The while/else is Python’s no-break clause; the extendleft(reversed(...)) returns scheduled seqs to the front in original order.
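The victim-selection logic is easier to follow with the block manager reduced to a counter. The sketch below is a toy model under stated assumptions: requests are plain ids, needs_block and held are hypothetical dicts standing in for can_append and the block table, and the function name decode_pass is invented for illustration.

```python
from collections import deque


def decode_pass(running, waiting, free_blocks, needs_block, held, max_num_seqs):
    """One decode pass. Returns (scheduled, free_blocks).

    running/waiting are deques of request ids, oldest first.
    needs_block[r] is True if r crosses a block boundary this step.
    held[r] is the number of KV blocks r currently holds.
    """
    scheduled = []

    def preempt(r):
        nonlocal free_blocks
        free_blocks += held[r]   # all of the victim's blocks are freed
        held[r] = 0
        waiting.appendleft(r)    # retry from the front of the queue

    while running and len(scheduled) < max_num_seqs:
        r = running.popleft()              # serve the oldest first
        while needs_block[r] and free_blocks == 0:
            if running:
                preempt(running.pop())     # shed the youngest
            else:
                preempt(r)                 # nobody left to shed: shed self
                break
        else:                              # runs only if no break occurred
            if needs_block[r]:
                free_blocks -= 1
                held[r] += 1
            scheduled.append(r)
    running.extendleft(reversed(scheduled))  # restore original order
    return scheduled, free_blocks
```

Running it with three requests and an empty pool shows the youngest request landing at the front of waiting while the two older ones decode.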
postprocess (nanovllm/engine/scheduler.py:81-92) closes the loop — registers new block hashes, advances counters, appends the sampled token (discarding it for mid-prompt chunks), and retires finished sequences:
for seq, token_id in zip(seqs, token_ids):
    self.block_manager.hash_blocks(seq)
    seq.num_cached_tokens += seq.num_scheduled_tokens
    seq.num_scheduled_tokens = 0
    if is_prefill and seq.num_cached_tokens < seq.num_tokens:
        continue
    seq.append_token(token_id)
    if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens:
        seq.status = SequenceStatus.FINISHED
        self.block_manager.deallocate(seq)
        self.running.remove(seq)
That continue is chunked prefill’s tell: the sampler always samples, but the token is thrown away unless this chunk completed the prompt.
Model runner: forward pass, CUDA graphs, sampling
ModelRunner (nanovllm/engine/model_runner.py) is one process per TP rank. Rank 0 lives in the engine process; ranks 1..N-1 are spawned (nanovllm/engine/llm_engine.py:24-31) and block in loop() reading method calls from a 1 MB SharedMemory segment signaled by mp.Event (nanovllm/engine/model_runner.py:61-89). That is a ~30-line RPC standing in for vLLM’s MultiprocExecutor and its shared-memory message queues. Sequence.__getstate__ (nanovllm/engine/sequence.py:72-74) keeps that pickle tiny: decode steps ship only the last token, not the whole sequence.
At startup it runs a max-size dummy prefill to measure peak activation memory, then sizes the KV cache to fill whatever is left and aliases slices of one giant tensor into every attention layer:
num_kv_heads = hf_config.num_key_value_heads // self.world_size
head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.dtype.itemsize
config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
assert config.num_kvcache_blocks > 0
self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, head_dim)
layer_id = 0
for module in self.model.modules():
    if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
        module.k_cache = self.kv_cache[0, layer_id]
        module.v_cache = self.kv_cache[1, layer_id]
        layer_id += 1
(nanovllm/engine/model_runner.py:110-121.) Per step, run() (nanovllm/engine/model_runner.py:214-220) builds tensors via prepare_prefill (ragged varlen batch: flattened input_ids, cu_seqlens_q/k, and a slot_mapping from each token to its physical cache slot — nanovllm/engine/model_runner.py:129-170) or prepare_decode — one token per sequence:
for seq in seqs:
    input_ids.append(seq.last_token)
    positions.append(len(seq) - 1)
    context_lens.append(len(seq))
    slot_mapping.append(seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1)
(nanovllm/engine/model_runner.py:177-181.) The metadata travels to the attention layers via a process-global Context (nanovllm/utils/context.py:21-23) — nano’s version of vLLM’s ForwardContext, sidestepping threading metadata through nn.Module.forward signatures.
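The slot arithmetic is the whole of paging: a logical token position splits into a block index and an offset, and the block table translates the index into a physical block. A small sketch with hypothetical helper names shows that the decode formula above is the general translation applied to the last position.

```python
def slot_for_position(block_table, position, block_size):
    """Physical cache slot of the token at logical `position`."""
    block_index, offset = divmod(position, block_size)
    return block_table[block_index] * block_size + offset


def decode_slot(block_table, seq_len, block_size):
    """Slot written during decode: where the sequence's last token goes."""
    last_block_num_tokens = seq_len - (len(block_table) - 1) * block_size
    return block_table[-1] * block_size + last_block_num_tokens - 1
```

For a 600-token sequence with block table [7, 2, 9] and 256-token blocks, the last token sits at offset 87 of physical block 9, so both functions return 9 * 256 + 87 = 2391.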
Decode steps with batch ≤ 512 replay a pre-captured CUDA graph instead of relaunching kernels (capture_cudagraph at nanovllm/engine/model_runner.py:222-257 captures one graph per batch-size bucket: 1, 2, 4, 8, then every multiple of 16 up to 512, about 36 graphs sharing one memory pool). Replay picks the smallest bucket that fits the live batch, copies inputs into static buffers, and fires:
bs = input_ids.size(0)
context = get_context()
graph = self.graphs[next(x for x in self.graph_bs if x >= bs)]
graph_vars = self.graph_vars
graph_vars["input_ids"][:bs] = input_ids
graph_vars["positions"][:bs] = positions
graph_vars["slot_mapping"].fill_(-1)
graph_vars["slot_mapping"][:bs] = context.slot_mapping
graph_vars["context_lens"].zero_()
graph_vars["context_lens"][:bs] = context.context_lens
graph_vars["block_tables"][:bs, :context.block_tables.size(1)] = context.block_tables
graph.replay()
return self.model.compute_logits(graph_vars["outputs"][:bs])
(nanovllm/engine/model_runner.py:200-212.) The slot_mapping.fill_(-1) padding is honored by the Triton store kernel, which skips slot −1 (nanovllm/layers/attention.py:23). Prefill always runs eager — graphs only pay off for the many small identical decode launches.
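Bucket selection and padding are the only non-obvious parts of replay. The sketch below assumes a bucket list of 1, 2, 4, 8 followed by multiples of 16 up to 512 (chosen to match the roughly 36 graphs this guide mentions for enforce_eager); the list and the helper names are assumptions, and bisect replaces the linear next(...) scan.

```python
import bisect

# Assumed capture sizes: small powers of two, then multiples of 16.
GRAPH_BS = [1, 2, 4, 8] + list(range(16, 512 + 1, 16))


def pick_graph_bs(bs, graph_bs=GRAPH_BS):
    """Smallest captured batch size >= bs, or None to fall back to eager."""
    i = bisect.bisect_left(graph_bs, bs)
    return graph_bs[i] if i < len(graph_bs) else None


def pad_slot_mapping(slot_mapping, padded_bs):
    """Pad with -1 so the cache-write kernel skips the unused rows."""
    return slot_mapping + [-1] * (padded_bs - len(slot_mapping))
```

A live batch of 9 sequences therefore replays the 16-wide graph, with 7 padded rows whose slot of -1 tells the store kernel to write nothing.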
Inside each attention layer, the whole paged-attention story is two flash-attn calls plus the cache-write kernel:
if k_cache.numel() and v_cache.numel():
    store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
if context.is_prefill:
    if context.block_tables is not None:    # prefix cache
        k, v = k_cache, v_cache
    o = flash_attn_varlen_func(q, k, v,
                               max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
                               max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
                               softmax_scale=self.scale, causal=True, block_table=context.block_tables)
else:    # decode
    o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
                                cache_seqlens=context.context_lens, block_table=context.block_tables,
                                softmax_scale=self.scale, causal=True)
(nanovllm/layers/attention.py:62-74.) Sampling happens on rank 0 only — ParallelLMHead first slices each sequence’s last position (nanovllm/layers/embed_head.py:58-60) so logits are computed for one token per sequence, TP shards are gathered to rank 0 (nanovllm/layers/embed_head.py:62-65), then the entire sampler is six lines of Gumbel-max:
@torch.compile
def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
    logits = logits.float().div_(temperatures.unsqueeze(dim=1))
    probs = torch.softmax(logits, dim=-1)
    sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)).argmax(dim=-1)
    return sample_tokens
(nanovllm/layers/sampler.py:7-12.) Dividing probabilities by i.i.d. Exponential(1) noise and taking argmax is an exact categorical sample — one fused kernel, no top-k/top-p machinery.
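The claim that argmax(p / E) is an exact categorical sample can be checked empirically with the standard library alone. The reasoning: E_i / p_i is exponential with rate p_i, and the minimum of independent exponentials falls on index i with probability p_i divided by the sum of the rates. The sketch below is a pure-Python stand-in for the tensor code, with hypothetical function names.

```python
import random
from collections import Counter


def exponential_race_sample(probs, rng):
    """Draw index i with probability probs[i] via argmax(p_i / E_i)."""
    best_index, best_score = -1, -1.0
    for i, p in enumerate(probs):
        noise = max(rng.expovariate(1.0), 1e-10)  # mirrors clamp_min_(1e-10)
        score = p / noise
        if score > best_score:
            best_index, best_score = i, score
    return best_index


def empirical_distribution(probs, num_draws, seed=0):
    """Sample repeatedly and return the observed frequency of each index."""
    rng = random.Random(seed)
    counts = Counter(exponential_race_sample(probs, rng) for _ in range(num_draws))
    return [counts[i] / num_draws for i in range(len(probs))]
```

With enough draws the observed frequencies settle close to the input probabilities, which is the property the one-line tensor version relies on.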
What a Sequence carries
Sequence (nanovllm/engine/sequence.py) is the unit the scheduler moves around: token list, num_prompt_tokens, num_cached_tokens (KV entries that already exist — from prefix hits, completed chunks, and each decode step), num_scheduled_tokens (work assigned this step), its block_table (logical→physical block ids), and the three sampling fields. The class-level block_size is injected from config at engine startup (nanovllm/engine/llm_engine.py:21).
add_request()            schedule(): blocks allocated,                 postprocess(): EOS or
           ┌─────────┐      prompt chunks scheduled      ┌─────────┐    max_tokens reached     ┌──────────┐
arrival ──►│ WAITING │ ─────────────────────────────────►│ RUNNING │──────────────────────────►│ FINISHED │
           └─────────┘                                   └─────────┘      (blocks freed,       └──────────┘
                ▲                                             │         seq leaves running)
                └──────── preempt(): ALL blocks freed ◄───────┘
                   re-enters at FRONT of waiting; re-prefill is
                   cheap if its blocks are still in the hash map
There is no SWAPPED state (nanovllm/engine/sequence.py:8-11) — vLLM V0 had one; nano (like vLLM V1’s default) preempts by recompute.
nano-vllm → vLLM V1 map
| nano-vllm | vLLM V1 equivalent | Delta |
|---|---|---|
| LLMEngine (nanovllm/engine/llm_engine.py) | LLMEngine / AsyncLLM + EngineCore (vllm/v1/engine/llm_engine.py, core.py) | vLLM splits the core loop into its own process behind ZMQ; adds async streaming |
| Scheduler (nanovllm/engine/scheduler.py) | Scheduler (vllm/v1/core/sched/scheduler.py) | V1 erased the prefill/decode dichotomy: one token-budget loop assigns num_scheduled_tokens per request and mixes both in a single batch; adds priority, spec-decode tokens, structured-output gating |
| waiting / running deques | same-named queues in V1’s scheduler | V1 adds priority policies; same FCFS default |
| Sequence (nanovllm/engine/sequence.py) | Request (vllm/v1/request.py) + GPU-side CachedRequestState (vllm/v1/worker/gpu_input_batch.py) | V0’s Sequence/SequenceGroup is the historical ancestor |
| BlockManager (nanovllm/engine/block_manager.py) | KVCacheManager (vllm/v1/core/kv_cache_manager.py) + BlockPool (block_pool.py), hashing in kv_cache_utils.py | Same chained block hashing; vLLM adds LRU eviction, hybrid KV groups (sliding window etc.), and an explicit free-queue order |
| num_cached_tokens / chunk accounting | num_computed_tokens on Request | Same concept, same role in chunked prefill |
| ModelRunner (nanovllm/engine/model_runner.py) | GPUModelRunner + Worker (vllm/v1/worker/gpu_model_runner.py, gpu_worker.py) | vLLM keeps a persistent on-GPU input batch updated by diffs instead of rebuilding tensors per step |
| SharedMemory + Event RPC (nanovllm/engine/model_runner.py:61-89) | MultiprocExecutor (vllm/v1/executor/multiproc_executor.py) shm message queues; Ray for multi-node | Same shape: rank 0 broadcasts method calls to workers |
| Context global (nanovllm/utils/context.py) | ForwardContext (vllm/forward_context.py) + per-backend AttentionMetadata | Identical trick: process-global side channel into attention |
| flash-attn calls (nanovllm/layers/attention.py:67-74) | pluggable backends (vllm/v1/attention/backends/: FlashAttention, FlashInfer, Triton, FlexAttention…) | nano hardcodes one backend |
| capture_cudagraph (nanovllm/engine/model_runner.py:222-257) | full + piecewise CUDA graphs via torch.compile (vllm/compilation/) | piecewise = graph the transformer chunks, leave attention eager |
| Sampler (nanovllm/layers/sampler.py) | Sampler (vllm/v1/sample/sampler.py) | vLLM adds top-k/p, penalties, logprobs, seeds, bad-words, structured-output bitmasks |
| loader.py + weight_loader attrs | vllm/model_executor/model_loader/ | Same per-parameter loader idiom, same packed_modules_mapping for fused QKV/gate-up |
What’s deliberately omitted
Do not expect to find any of the following — knowing they’re absent is what keeps the codebase at 1.4k lines:
- Serving: no API server, no OpenAI endpoints, no streaming, no async engine; generate() is blocking batch, and detokenization happens once at the end (nanovllm/engine/llm_engine.py:89).
- Sampling: no top-k/top-p, penalties, logprobs, stop strings, seeds, beam search; greedy decoding is explicitly forbidden (nanovllm/sampling_params.py:11).
- Structured output: no grammar/JSON-schema constrained decoding (no logits-processor hook at all).
- Quantization: none (no GPTQ/AWQ/FP8) — checkpoints load at the HF config dtype, period.
- Adapters & spec decode: no LoRA/multi-LoRA, no speculative decoding (EAGLE/Medusa/ngram).
- Models: Qwen3 dense only — no MoE, no multimodal, no sliding-window/hybrid attention, no encoder-decoder.
- Parallelism: single-node TP only (≤8, NCCL hardcoded); no pipeline/expert/data parallel, no Ray, no multi-node.
- KV mobility: preemption is recompute-only — no CPU swap, no KV offload tiers, no KV connectors / disaggregated prefill-decode transfer.
- Operations: no metrics, tracing, health checks, or scheduling policies beyond FCFS.
- Portability: CUDA only — no ROCm/CPU/TPU backends.
Connections to your other study repos
- vllm: direct target. Read nano’s scheduler, then vllm/v1/core/sched/scheduler.py; nano’s BlockManager, then vllm/v1/core/kv_cache_manager.py + block_pool.py. The single most instructive diff: V1 has no is_prefill flag anywhere in scheduling. Convince yourself nano’s two-phase schedule() collapses into V1’s one token-budget loop.
- sglang: same engine loop, different prefix-cache data structure. RadixAttention keeps a radix tree over token sequences (python/sglang/srt/mem_cache/radix_cache.py) where nano keeps a flat hash map; the tree enables cache-aware scheduling (reordering the waiting queue by match length, in contrast to nano’s strict FCFS) and partial-block matching. sglang’s Scheduler is python/sglang/srt/managers/scheduler.py.
- flashinfer: nano’s entire kernel contract is three calls, the Triton cache-write plus flash_attn_varlen_func / flash_attn_with_kvcache (nanovllm/layers/attention.py). flashinfer provides exactly those two attention ops (BatchPrefill/BatchDecode over paged KV) with a plan/run split and JIT specialization; it’s what vLLM/sglang slot into the seam nano hardcodes.
- xgrammar: would plug in between compute_logits and the sampler (nanovllm/engine/model_runner.py:217-218) as a per-sequence token bitmask applied to logits. nano having no hook there is the lesson: you can see precisely where structured output costs latency and why vLLM gates “waiting for FSM compilation” in the scheduler.
- dynamo: nano’s strict per-step prefill/decode mutual exclusion is the single-GPU shadow of disaggregated P/D serving; dynamo moves that split across worker pools and ships KV via NIXL. Its KV-aware router consumes events about exactly the state that is hash_to_block_id here.
- llm-d / gateway-api-inference-extension: these balance across replicas of engines like this one. The EPP scorers (prefix-cache-aware, queue-depth, load) are estimating, from outside, what can_allocate would return and how long waiting is inside each pod. The instrumentation section below prints precisely the signal set an inference gateway wants engines to export.
Running it on your RTX 5080 (16 GB)
Use WSL2, not native Windows. Three hard blockers natively: dist.init_process_group("nccl", ...) is hardcoded even for TP=1 (nanovllm/engine/model_runner.py:26) and Windows torch builds ship no NCCL; Triton has no official Windows wheels (the store_kvcache kernel and torch.compile need it); flash-attn has no official Windows wheels. All three work in WSL2 Ubuntu with the standard CUDA-on-WSL driver. (Fallback hack if you must stay native: locally change "nccl" to "gloo", install triton-windows, and hunt a community flash-attn wheel — WSL2 is less total work.)
Blackwell specifics. The 5080 is sm_120: you need a cu128 build of PyTorch (≥2.7) and a flash-attn built against it — there’s frequently no prebuilt sm_120 wheel, so expect a source build:
pip install torch --index-url https://download.pytorch.org/whl/cu128
MAX_JOBS=4 TORCH_CUDA_ARCH_LIST="12.0" pip install flash-attn --no-build-isolation # long compile
pip install transformers xxhash huggingface_hub
huggingface-cli download Qwen/Qwen3-4B --local-dir ~/huggingface/Qwen3-4B/
Config asserts the model path is a local directory (nanovllm/config.py:21) — you cannot pass a hub id.
Model choice. No quantization support means weights load in bf16, so the ceiling is lower than you’d guess. KV cost per 256-token block is 2 * layers * 256 * kv_heads * head_dim * 2 bytes (nanovllm/engine/model_runner.py:112):
| Model | bf16 weights | KV per token | KV per block | Fits 16 GB? |
|---|---|---|---|---|
| Qwen3-0.6B | ~1.4 GB | 112 KiB | 28 MiB | Yes — huge cache (~100k tokens), best for scheduler experiments |
| Qwen3-1.7B | ~3.4 GB | 112 KiB | 28 MiB | Yes, comfortable |
| Qwen3-4B | ~8 GB | 144 KiB | 36 MiB | Yes — the sweet spot; roughly 130–160 blocks (~35–40k KV tokens) left |
| Qwen3-8B | ~16.4 GB | — | — | No. Quantized 8B isn’t possible here; do Qwen3-8B-AWQ in real vLLM instead |
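The table's KV figures follow directly from the per-block formula. As a worked check, the sketch below recomputes them; the layer, KV-head, and head-dimension values are assumptions chosen to reproduce the table (confirm them against each model's config.json), and the 5 GiB budget is purely illustrative.

```python
KIB, MIB, GIB = 1024, 1024**2, 1024**3


def kv_bytes_per_token(num_layers, num_kv_heads, head_dim, dtype_bytes=2):
    """Bytes of KV cache one token occupies across all layers."""
    # Leading 2 = one K entry and one V entry per layer.
    return 2 * num_layers * num_kv_heads * head_dim * dtype_bytes


def kv_bytes_per_block(block_size, **shape):
    """Bytes of one paged KV block holding block_size tokens."""
    return block_size * kv_bytes_per_token(**shape)


def blocks_in_budget(budget_bytes, block_size, **shape):
    """Whole blocks that fit in the memory left after weights and activations."""
    return budget_bytes // kv_bytes_per_block(block_size, **shape)


# Assumed model shapes (bf16, so dtype_bytes=2).
QWEN3_0_6B = dict(num_layers=28, num_kv_heads=8, head_dim=128)
QWEN3_4B = dict(num_layers=36, num_kv_heads=8, head_dim=128)
```

Under these assumptions a 5 GiB KV budget for Qwen3-4B yields 142 blocks, or 36,352 tokens, which sits inside the 130 to 160 block range quoted in the table.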
Start with enforce_eager=True (skips capturing ~36 CUDA graphs; far friendlier stack traces), gpu_memory_utilization=0.8 if your desktop compositor holds VRAM, and check what the profiler decided: llm.config isn’t exposed, but llm.scheduler.block_manager.free_block_ids length right after construction equals num_kvcache_blocks.
Watching the scheduler think. Without touching repo files, monkeypatch from your own driver script:
from nanovllm import LLM, SamplingParams
from nanovllm.engine.scheduler import Scheduler
orig_schedule, orig_preempt = Scheduler.schedule, Scheduler.preempt
def schedule(self):
    seqs, is_prefill = orig_schedule(self)
    print(f"{'PREFILL' if is_prefill else 'DECODE '}"
          f" batch={len(seqs)}"
          f" sched_tokens={sum(s.num_scheduled_tokens for s in seqs)}"
          f" waiting={len(self.waiting)} running={len(self.running)}"
          f" free_blocks={len(self.block_manager.free_block_ids)}"
          f" cached_hashes={len(self.block_manager.hash_to_block_id)}")
    return seqs, is_prefill
def preempt(self, seq):
    print(f"  PREEMPT seq={seq.seq_id} len={len(seq)} (recompute later)")
    orig_preempt(self, seq)
Scheduler.schedule, Scheduler.preempt = schedule, preempt
Experiments that make each mechanism visible:
- Prefix cache: call generate() twice with the same long prompt (>512 tokens, so it spans ≥2 full blocks). The second run’s PREFILL line shows sched_tokens shrunk by 256 per matched block, and seq.num_cached_tokens confirms the hits. Note hits land only across steps/calls: two identical prompts admitted in the same prefill batch don’t dedupe, because hash_blocks runs in postprocess.
- Chunked prefill: set max_num_batched_tokens=1024 and send a ~4k-token prompt. You get successive PREFILL steps of 1024 for one sequence, which joins running only on the last chunk.
- Preemption: shrink the pool (gpu_memory_utilization=0.5) and run bench.py-style load: many sequences with ignore_eos=True, max_tokens=1024. When free_blocks hits 0 mid-decode, watch PREEMPT pick the youngest sequences and note their later re-prefill is cheap (cached_hashes still holds their blocks).
- Continuous batching itself: give 32 prompts wildly different max_tokens; the DECODE batch= count steps down as short sequences finish while long ones keep going. No stragglers waiting on a static batch, no head-of-line blocking.
- CUDA graphs: flip enforce_eager and compare the tqdm decode tok/s; then read run_model (nanovllm/engine/model_runner.py:196-212) knowing exactly what you paid for.
SGLang
1. What it is, and how it differs from vLLM
SGLang is an LLM serving stack from the LMSYS team with two load-bearing halves. The first is a Python serving runtime (“SRT”, python/sglang/srt/) that owns the GPU: continuous batching, paged KV, RadixAttention prefix caching, grammar-constrained decoding. The second is a production Rust router/data plane (sgl-model-gateway/, historically named sgl-router) that load-balances a fleet of those runtimes.
Versus vLLM: both do continuous batching + paged KV + prefix caching, but four things differ. SGLang’s prefix cache is a first-class radix tree over token IDs: vLLM v1 hashes fixed-size blocks, while SGLang matches at arbitrary token granularity within page alignment and schedules around the tree (its waiting-queue policy literally sorts by longest-prefix-match against cache contents). Its scheduler overlap (“zero-overhead scheduler”) runs CPU bookkeeping for batch N concurrently with GPU forward of batch N+1 by default. It retracts decode requests back to the waiting queue under memory pressure rather than using vLLM-style preemption semantics. And, the part that matters for you, it ships a cross-replica cache-aware router as a supported component rather than leaving fleet-level routing to an external gateway.
The repo also carries sgl-kernel/ (AOT CUDA kernels), a frontend DSL (python/sglang/lang/, the original “Structured Generation Language” that gave the project its name; co-pilot-style programs with gen() calls; you can skip it), and a gRPC protocol (proto/, rust/sglang-grpc/) that lets the Rust gateway talk to schedulers while doing tokenization in Rust.
2. Why you care
You sketched cache-aware routing with prefix trees before you knew this existed. sgl-router is the production version of that exact sketch, and it validates every design intuition you’d have brought from Envoy land: the router keeps an approximate radix tree per worker, built purely from routing history (no worker telemetry needed), routes to the worker with the highest prefix match when load is balanced, and falls back to shortest-queue when an imbalance detector trips — with the imbalance test being the same abs+rel two-threshold pattern you’d use for outlier detection in an LB. The entire policy is ~500 lines of Rust in sgl-model-gateway/src/policies/cache_aware.rs plus a ~2300-line lock-free radix tree in tree.rs, and reading it is the fastest way to see which corners production cut that your sketch probably didn’t (char-level text instead of tokens, optimistic insert-at-selection-time, zero feedback loop from workers). It also gives you the counterpoint to study against NVIDIA Dynamo’s KV-event-based “ground truth” approach and the Envoy-based gateway-api-inference-extension — three points on the design space of “how much does the LB know about cache state, and who pays for that knowledge.”
3. Architecture map
sglang/ (repo root)
├── python/sglang/
│ ├── lang/ # frontend DSL (skippable)
│ ├── launch_server.py # CLI: python -m sglang.launch_server
│ ├── bench_serving.py # load generator (has shared-prefix dataset)
│ └── srt/ # THE RUNTIME
│ ├── entrypoints/http_server.py # FastAPI app (OpenAI-compatible + /generate)
│ ├── entrypoints/engine.py # spawns the 3-process pipeline
│ ├── managers/tokenizer_manager.py # proc 1: tokenize, track request state
│ ├── managers/scheduler.py # proc 2 (per TP rank): batching, KV, the GPU loop
│ ├── managers/schedule_batch.py # Req/ScheduleBatch state machines
│ ├── managers/schedule_policy.py # waiting-queue priority (LPM) + PrefillAdder
│ ├── managers/detokenizer_manager.py # proc 3: incremental detokenization
│ ├── mem_cache/radix_cache.py # RadixAttention's tree (the real one)
│ ├── mem_cache/memory_pool.py # req_to_token + token_to_kv paged pools
│ ├── layers/radix_attention.py # nn.Module that dispatches to attn backends
│ ├── layers/attention/ # flashinfer_backend.py, triton_backend.py, ...
│ ├── constrained/ # xgrammar/llguidance/outlines bitmask backends
│ └── disaggregation/ # prefill/decode (PD) split runtime
├── sgl-model-gateway/ # THE ROUTER (formerly sgl-router; pip: sglang-router)
│ ├── src/policies/ # cache_aware.rs, tree.rs, power_of_two.rs, ...
│ ├── src/core/ # worker.rs (health/load/CB), worker_registry, retry
│ ├── src/routers/http/ # router.rs (regular), pd_router.rs (PD pairs)
│ ├── src/routers/grpc/ # Rust-native tokenizer pipeline → SRT gRPC workers
│ └── bindings/python/ # python -m sglang_router.launch_router
├── sgl-kernel/ # AOT CUDA/C++ kernels (own wheel)
└── proto/, rust/sglang-grpc/ # gRPC contract between gateway and scheduler
Process pipeline inside one worker: HTTP → TokenizerManager → ZMQ → Scheduler (one process per GPU; owns model + KV) → ZMQ → DetokenizerManager → ZMQ back to TokenizerManager → SSE out. Spawning is in python/sglang/srt/entrypoints/engine.py:584 (scheduler procs) and :196 (process wiring). Note ZMQ IPC, not gRPC, between these processes — the gRPC path is an alternative front door used by the Rust gateway (python/sglang/srt/entrypoints/grpc_server.py).
4. Core mechanisms
4.1 sgl-router: cache-aware load balancing (Rust)
Everything below lives under sgl-model-gateway/src/. The policy abstraction is LoadBalancingPolicy (policies/mod.rs:42) — select_worker(workers, info) -> Option<usize> where SelectWorkerInfo (policies/mod.rs:160) carries request_text (HTTP path), tokens (gRPC path), headers, and a consistent-hash ring. Policies: random, round_robin, cache_aware, power_of_two, prefix_hash, consistent_hashing, bucket, manual. Default is cache_aware (src/main.rs:150).
The routing decision
CacheAwarePolicy::select_worker (policies/cache_aware.rs:376) is a two-mode policy. First it computes min/max in-flight load across workers and runs the imbalance detector:
// Check if load is imbalanced
let is_imbalanced = max_load.saturating_sub(min_load) > self.config.balance_abs_threshold
    && (max_load as f32) > (min_load as f32 * self.config.balance_rel_threshold);
if is_imbalanced {
    return self.select_worker_min_load(
        workers,
        &request_text,
        &healthy_indices,
        &tree_key,
        max_load,
        min_load,
    );
}
Source: policies/cache_aware.rs:400-413. Both conditions must hold (absolute gap AND ratio). It is the usual belt-and-suspenders pairing: a tiny absolute delta at low traffic fails the gap test, and a proportionally small delta at high traffic fails the ratio test, so both cases stay in cache-affinity mode. CLI defaults: cache_threshold 0.3, balance_abs_threshold 64, balance_rel_threshold 1.5, eviction every 120s, max_tree_size 67108864 chars (src/main.rs:153-171; the in-code CacheAwareConfig::default() at policies/mod.rs:106-116 differs, and the CLI wins).
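To get a feel for how the two-condition check behaves, here is a minimal Python sketch of the same predicate. The thresholds are the CLI defaults quoted above; the function name and the sample load vectors are illustrative, not taken from the router.

```python
# CLI defaults quoted above; the function and variable names are illustrative.
BALANCE_ABS_THRESHOLD = 64
BALANCE_REL_THRESHOLD = 1.5


def is_imbalanced(loads, abs_threshold=BALANCE_ABS_THRESHOLD,
                  rel_threshold=BALANCE_REL_THRESHOLD):
    """True only when BOTH the absolute gap and the ratio test trip."""
    max_load, min_load = max(loads), min(loads)
    gap_exceeded = (max_load - min_load) > abs_threshold
    ratio_exceeded = max_load > min_load * rel_threshold
    return gap_exceeded and ratio_exceeded


low_traffic = is_imbalanced([1, 9])          # large ratio, but a gap of only 8
high_traffic = is_imbalanced([1000, 1100])   # gap of 100, but a ratio of only 1.1
skewed = is_imbalanced([20, 200])            # gap of 180 and a ratio of 10
print(low_traffic, high_traffic, skewed)     # False False True
```

Only the third case leaves cache-affinity mode, which is the behaviour the paragraph above describes.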
When balanced, it consults the approximate tree:
let result = tree.prefix_match_with_counts(text);
let match_rate = if result.input_char_count == 0 {
    0.0
} else {
    result.matched_char_count as f32 / result.input_char_count as f32
};
// Select worker without String allocation
let selected_idx = if match_rate > self.config.cache_threshold {
    // Cache hit path: find worker by URL (compare &str directly, no allocation)
    let tenant_url: &str = &result.tenant;
    workers
        .iter()
        .position(|w| w.url() == tenant_url)
        .filter(|&idx| workers[idx].is_healthy())
}
— policies/cache_aware.rs:425-439. Below the threshold, the else-branch (:440-446) picks the min-load healthy worker — i.e., “no one has meaningful affinity for this prefix, so place it where there’s most headroom” (the comment block at :29-30 frames it as “smallest tree = most available cache capacity”; the implementation uses live load). Then the critical learning step — the router inserts the request text under the chosen worker at selection time:
if let Some(idx) = selected_idx {
    // Update the tree with this request (use worker URL directly, no allocation)
    tree.insert(text, workers[idx].url());
}
— policies/cache_aware.rs:448-450. That insert is what builds each worker’s “approximate cache state”: the router assumes that whatever it routes to worker W will end up in W’s real radix cache. Even in imbalanced/min-load mode it still inserts (:335-343), so affinity keeps tracking reality during load-balancing episodes. If the matched tenant is dead, the policy removes it from the tree and falls back (:471-491).
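The balanced-mode path (match, threshold, optimistic insert) can be sketched in a few lines of Python. This is a toy under loud assumptions: the class and method names are hypothetical, a per-worker list of routed strings stands in for the shared radix tree, and the imbalance branch, health checks, and load decrement are all left out.

```python
import os


class ToyCacheAwareRouter:
    """Hypothetical stand-in for the balanced-mode path; not the router's API."""

    def __init__(self, workers, cache_threshold=0.3):
        self.history = {w: [] for w in workers}  # replaces the radix tree
        self.load = {w: 0 for w in workers}      # in-flight count, never decremented here
        self.cache_threshold = cache_threshold

    def select(self, text):
        # Longest character prefix shared with anything already routed.
        best_worker, best_len = None, 0
        for worker, seen in self.history.items():
            for old in seen:
                n = len(os.path.commonprefix([text, old]))
                if n > best_len:
                    best_worker, best_len = worker, n
        match_rate = best_len / len(text) if text else 0.0
        if match_rate > self.cache_threshold:
            chosen = best_worker                        # affinity path
        else:
            chosen = min(self.load, key=self.load.get)  # min-load path
        self.history[chosen].append(text)  # optimistic insert at selection time
        self.load[chosen] += 1
        return chosen


router = ToyCacheAwareRouter(["w0", "w1"])
first = router.select("SYSTEM: you are helpful. Q1")
second = router.select("totally different prompt")
third = router.select("SYSTEM: you are helpful. Q2")
print(first, second, third)  # w0 w1 w0
```

The third request returns to w0 purely because of the router's own routing history, which is the sense in which the cache state is "approximate".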
The text being matched comes from extract_text_for_routing() on the typed request (routers/http/router.rs:203) — raw prompt/messages text, characters not tokens. The gRPC router can instead route on token IDs via PrefixHashPolicy (policies/prefix_hash.rs, hashing the first --prefix-token-count 256 tokens to a consistent-hash ring — a cheaper, stateless cousin of cache-aware).
The approximate radix tree
policies/tree.rs is a concurrent multi-tenant radix tree: one tree shared by all workers of a (pool, model) group, where each node tracks which workers (tenants) own that prefix segment:
#[derive(Debug)]
struct Node {
    /// Children nodes indexed by first character.
    /// Using custom hasher optimized for char keys.
    children: DashMap<char, NodeRef, CharHasherBuilder>,
    /// Node text with cached character count
    text: RwLock<NodeText>,
    /// Per-tenant last access epoch for LRU ordering. Using TenantId (Arc<str>) for cheap cloning.
    tenant_last_access_time: DashMap<TenantId, u64>,
    /// Parent pointer for upward traversal during timestamp updates
    parent: RwLock<Option<NodeRef>>,
    /// Cached last-accessed tenant for O(1) lookup during prefix match.
    /// Avoids O(shards) DashMap iteration in the common case.
    last_tenant: parking_lot::RwLock<Option<TenantId>>,
}
— policies/tree.rs:229-243. Notes a traffic-infra person will appreciate: TenantId = Arc<str> interning (:40), per-node DashMaps with tuned shard counts (32 for root, 8 elsewhere, :23-24), and LRU timestamps from a global atomic epoch counter instead of wall clock (:218-227 — “no syscall overhead… perfectly monotonic”). insert (:362) walks char-by-char, splitting nodes on partial matches (:435-482) like any radix trie; prefix_match_with_counts (:531) walks down and returns the deepest node’s tenant. Two deliberate approximations in the hot path: tenant resolution uses the last_tenant cache or whatever tenant the DashMap iterator yields first (:568-599) — when multiple workers own the same prefix it does not pick the least-loaded owner — and LRU timestamps update only 1-in-8 matches:
// Update timestamp probabilistically (1 in 8 matches) to reduce DashMap contention.
// LRU eviction doesn't need perfect accuracy - approximate timestamps suffice.
let epoch = get_epoch();
if epoch & 0x7 == 0 {
    curr.tenant_last_access_time
        .insert(Arc::clone(&tenant), epoch);
}
— policies/tree.rs:601-607.
Trees are keyed pool::model ("prefill::llama" / "decode::llama" / "regular::UNKNOWN_MODEL_ID", policies/cache_aware.rs:83-100) so in PD mode the prefill and decode pools can’t evict each other’s tenants — the doc comment at :76-82 records the regression where a shared tree collapsed into pool flip-flopping. Workers are seeded with an empty-string insert at registration (:185, :197).
Tree maintenance and eviction
A background thread runs every eviction_interval_secs and trims each tree to a per-tenant character budget (policies/cache_aware.rs:127-149). The eviction itself (policies/tree.rs:718) collects per-tenant leaves into a min-heap by LRU epoch and peels leaves until the tenant is under budget:
// Process eviction
while let Some(Reverse(entry)) = pq.pop() {
    let EvictionEntry { tenant, node, .. } = entry;
    if let Some(used_size) = self.tenant_char_count.get(tenant.as_ref()) {
        if *used_size <= max_size {
            continue;
        }
    }
}
— policies/tree.rs:745-753; when a leaf’s removal makes its parent a new leaf for that tenant, the parent is pushed back into the heap (:793-814), so eviction unwinds bottom-up exactly like the runtime’s real tree. remove_tenant (:823) does the same unwinding for worker removal. This mirrors your sketch’s “evict LRU leaves” but note the unit: characters per tenant, with no relationship to the worker’s actual KV-token budget.
Worker health, load, and failure handling
core/worker.rs is the Envoy-cluster-management analog. Each worker has an AtomicBool health flag, an atomic in-flight load counter, a circuit breaker, and a background health checker hitting its health endpoint. Health transitions use hysteresis — failure_threshold consecutive failures to go down, success_threshold consecutive successes to come back:
if self.is_healthy()
    && failures >= self.metadata.health_config.failure_threshold as usize
{
    self.set_healthy(false);
    self.consecutive_failures.store(0, Ordering::Release);
}
— core/worker.rs:740-745 (the symmetric success path is :726-731). Policy-visible “load” is the router’s own in-flight counter, maintained RAII-style — and only for policies that need it:
let load_guard = ["cache_aware", "manual"]
    .contains(&policy.name())
    .then(|| WorkerLoadGuard::new(worker.clone(), headers));
— routers/http/router.rs:300-302. WorkerLoadGuard::new increments; Drop decrements (core/worker.rs:1109-1140), and the guard is attached to the streaming response body (AttachedBody, :1147) so load stays held until the SSE stream actually finishes — the same trick as deferred-done callbacks in Envoy. Policy eligibility = healthy AND circuit breaker closed (policies/mod.rs:136-143). Separately, a LoadMonitor (core/worker_manager.rs:269) polls each worker’s /v1/loads?include=core and extracts aggregate.total_tokens (:207-230) — but only power_of_two consumes that token-level signal (policies/power_of_two.rs:113-117, with graceful degradation to request counts at :75-85). Cache-aware never sees worker-reported load.
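The two worker-side mechanisms above, health hysteresis and the RAII load guard, translate into a short Python sketch. Everything here is illustrative: the class name, the threshold defaults (3 and 2), and the context manager standing in for Rust's Drop are assumptions, not the gateway's code.

```python
from contextlib import contextmanager


class ToyWorker:
    """Hypothetical model of one worker's health flag and in-flight counter."""

    def __init__(self, failure_threshold=3, success_threshold=2):
        self.healthy = True
        self.load = 0
        self.failure_threshold = failure_threshold  # assumed default
        self.success_threshold = success_threshold  # assumed default
        self._fails = 0
        self._oks = 0

    def record_probe(self, ok):
        # Any probe result resets the opposite streak, so only
        # consecutive outcomes can flip the health flag.
        if ok:
            self._fails = 0
            self._oks += 1
            if not self.healthy and self._oks >= self.success_threshold:
                self.healthy = True
                self._oks = 0
        else:
            self._oks = 0
            self._fails += 1
            if self.healthy and self._fails >= self.failure_threshold:
                self.healthy = False
                self._fails = 0

    @contextmanager
    def load_guard(self):
        # Increment on entry, decrement on exit, even if the body raises.
        self.load += 1
        try:
            yield self
        finally:
            self.load -= 1


w = ToyWorker()
for ok in (False, False, True, False, False, False):
    w.record_probe(ok)
down_after_three_in_a_row = not w.healthy
w.record_probe(True)
still_down = not w.healthy
w.record_probe(True)
back_up = w.healthy
```

In the real router the guard travels with the streaming response body; the equivalent here would be holding the with-block open until the last SSE chunk has been sent.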
Where the approximation can go stale
This is the part to study hardest, because it’s the gap between your sketch and what you’d need to productionize one:
- Characters vs tokens vs pages. Router matches raw text chars; the worker’s real tree matches token IDs, page-aligned (radix_cache.py:380). Chat-template expansion, tokenizer merges, and page truncation all make match_rate an estimate. Good enough to pick a worker; never exact.
- Optimistic insert, no confirmation. The tree is updated at selection time (cache_aware.rs:450), before the worker even accepts the request. If the request fails or the retry layer (routers/http/router.rs:217) re-routes it, the losing worker keeps phantom affinity until LRU eviction. on_request_complete is explicitly a no-op (cache_aware.rs:507-517: “Could track success rates per worker”).
- Independent eviction clocks. Router evicts by char budget on a timer; the worker evicts its real radix tree by token budget whenever decode needs memory (radix_cache.py:537). A worker under memory pressure can dump exactly the prefixes the router still credits it with, and the router will happily keep routing “hits” that are now misses. Nothing reports worker evictions back (contrast Dynamo, below).
- Worker restarts / cache flushes. /flush_cache exists on both layers, but a worker that restarts behind a stable URL comes back cold while the router’s tree still says hot.
- Multi-tenant ambiguity. When several workers legitimately share a prefix, prefix_match_with_counts returns the cached/arbitrary last_tenant (tree.rs:568-599), not the best-loaded owner.
- Multi-router deployments. Each router instance learns only from its own traffic. There’s a gossip “mesh” to sync tree ops (cache_aware.rs:452-463), but the receive path “is not yet wired” (:269-272), so HA router pairs today have divergent trees.
None of these break correctness — worst case is a wasted prefill, the same graceful degradation your sketch had. That’s the design’s actual thesis: affinity is a hint, load is the guardrail, health is the gate.
PD mode and the gRPC pipeline (brief)
In prefill/decode disaggregation the router picks a pair: select_pd_pair (routers/http/pd_router.rs:793-872) runs independent policies over the prefill pool and decode pool (--prefill-policy cache_aware --decode-policy power_of_two is the canonical combo), then injects bootstrap_host/bootstrap_port/bootstrap_room into the JSON body (:228-276) so the decode worker can pull KV from the prefill worker. The gRPC router (routers/grpc/) goes further: it tokenizes in Rust, runs reasoning/tool-call parsing in Rust, and streams token IDs to SRT workers over the proto/ contract — the “industry-first gRPC pipeline” pitch in sgl-model-gateway/README.md. That’s the path where routing can use real token IDs instead of chars.
4.2 RadixAttention: the real prefix cache inside each worker
python/sglang/srt/mem_cache/radix_cache.py is the ground truth the router approximates. Same data structure, different payload — here tree nodes map token-ID keys to KV-cache page indices:
self.children = defaultdict(TreeNode)
self.parent: TreeNode = None
self.key: RadixKey = None
self.value: Optional[torch.Tensor] = None
self.lock_ref = 0
self.last_access_time = time.monotonic()
— radix_cache.py:206-211 (TreeNode.__init__, :201). key is a RadixKey (token ids + optional extra_key namespace — LoRA ID, sampling salt — so requests that must not share KV get disjoint subtrees, :337-349); value is a tensor of KV pool slot indices; lock_ref pins nodes that in-flight requests depend on.
Match (match_prefix, :337; helper :622-646) walks children keyed by the first page of tokens, splitting a node when the match ends mid-segment:
while len(key) > 0 and child_key in node.children.keys():
    child = node.children[child_key]
    child.last_access_time = access_time
    prefix_len = child.key.match(key, page_size=self.page_size)
    if prefix_len < len(child.key):
        new_node = self._split_node(child.key, child, prefix_len)
        value.append(new_node.value)
        node = new_node
        break
    else:
        value.append(child.value)
        node = child
        key = key[prefix_len:]
— radix_cache.py:629-641. The returned indices become req.prefix_indices — those KV slots are simply reused, which is the whole RadixAttention trick: prefill only computes the un-matched suffix.
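The page_size argument matters for how much of a match is actually reusable. As a minimal sketch, assuming a match is simply rounded down to whole pages (the helper name and token values are made up for illustration):

```python
def page_aligned_match_len(cached_tokens, query_tokens, page_size):
    """Length of the common token prefix, rounded down to a page multiple."""
    matched = 0
    for cached, queried in zip(cached_tokens, query_tokens):
        if cached != queried:
            break
        matched += 1
    return (matched // page_size) * page_size


cached = list(range(10))               # tokens already in the tree
query = list(range(7)) + [99, 98, 97]  # shares the first 7 tokens

reused_tokenwise = page_aligned_match_len(cached, query, page_size=1)  # 7
reused_paged = page_aligned_match_len(cached, query, page_size=4)      # 4
to_prefill = len(query) - reused_paged                                 # 6
```

With pages of 4 tokens, three matching tokens in the second page are recomputed, which is one of the reasons the router's character-level match rate can only be an estimate.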
Insert happens when requests finish (cache_finished_req, :417) or at each chunked-prefill boundary (cache_unfinished_req, :464 — so a 100k-token prompt becomes shareable chunk by chunk, and the request re-matches its own insertion then re-locks the deeper node, :495-516). _insert_helper (:678) returns how many tokens were already present so the duplicate KV slots get freed back to the allocator (:448-451).
Evict (:537) is called by the scheduler whenever the token allocator can’t satisfy a prefill/decode allocation. It heapifies the evictable leaves under a pluggable strategy and frees leaf-by-leaf, cascading to parents:
leaves = list(self.evictable_leaves)
eviction_heap = [
    (self.eviction_strategy.get_priority(node), node) for node in leaves
]
heapq.heapify(eviction_heap)

num_evicted = 0
while num_evicted < num_tokens and len(eviction_heap):
    _priority, x = heapq.heappop(eviction_heap)
    self.token_to_kv_pool_allocator.free(x.value)
    num_evicted += len(x.value)
    self._delete_leaf(x)
— radix_cache.py:543-555. Strategies live in mem_cache/evict_policy.py:10-66: LRU (default), LFU, FIFO, MRU, priority-aware, SLRU. Pinning: inc_lock_ref/dec_lock_ref (:566-600) move subtree token counts between evictable_size_ and protected_size_ — the scheduler’s admission math budgets against evictable size only. The RadixAttention layer (python/sglang/srt/layers/radix_attention.py:57) is thin: forward just dispatches q/k/v plus the paged metadata to the active attention backend (:145-153 → get_attn_backend().forward(...)), e.g. layers/attention/flashinfer_backend.py, which is what actually computes attention over the non-contiguous cached pages. Variants you can ignore initially: hiradix_cache.py (hierarchical GPU→CPU→storage tiers), swa_radix_cache.py (sliding-window), mamba_radix_cache.py, radix_cache_cpp.py.
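The leaf-by-leaf eviction with parent cascade and lock_ref pinning is easy to reproduce in isolation. The sketch below is a toy under stated assumptions: the Node class, the evict function, and the use of last_access as the LRU priority are illustrative, and no KV memory is involved.

```python
import heapq
import itertools


class Node:
    """Hypothetical tree node; num_tokens stands in for len(node.value)."""
    _ids = itertools.count()

    def __init__(self, parent, num_tokens, last_access):
        self.parent = parent
        self.children = set()
        self.num_tokens = num_tokens
        self.last_access = last_access
        self.lock_ref = 0
        self.uid = next(Node._ids)  # tie-breaker so the heap never compares Nodes
        if parent is not None:
            parent.children.add(self)


def evict(root, leaves, num_tokens):
    """Free at least num_tokens from unlocked leaves, oldest first."""
    heap = [(n.last_access, n.uid, n) for n in leaves if n.lock_ref == 0]
    heapq.heapify(heap)
    freed = 0
    while freed < num_tokens and heap:
        _, _, node = heapq.heappop(heap)
        freed += node.num_tokens
        parent = node.parent
        parent.children.discard(node)
        # Cascade: a parent that just lost its last child becomes a leaf.
        if parent is not root and not parent.children and parent.lock_ref == 0:
            heapq.heappush(heap, (parent.last_access, parent.uid, parent))
    return freed


root = Node(None, 0, 0)
a = Node(root, 4, last_access=1)
b = Node(a, 4, last_access=2)     # leaf under a
c = Node(root, 4, last_access=3)  # leaf pinned by an in-flight request
c.lock_ref = 1

freed = evict(root, [b, c], num_tokens=8)
```

Here b is evicted first, a becomes a leaf and is evicted next, and the pinned node c survives even though the caller asked for 8 tokens.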
4.3 Scheduler: continuous batching, SGLang-style
python/sglang/srt/managers/scheduler.py is a single-threaded event loop per GPU process. The classic loop is event_loop_normal (:1432): recv requests → build batch → run → process results. The default is event_loop_overlap (:1459), which hides all CPU work (detokenization dispatch, radix-cache updates, next-batch construction) behind the GPU:
# Get the next batch to run
batch = self.get_next_batch_to_run()
self.cur_batch = batch
disable_overlap_for_batch = self.is_disable_overlap_for_batch(batch)

# If we do not need to overlap the current batch with the last batch,
# we can process the last batch immediately.
if disable_overlap_for_batch:
    pop_and_process()

# Launch the current batch
if batch:
    batch_result = self.run_batch(batch)
    self.result_queue.append((batch.copy(), batch_result))
— scheduler.py:1481-1494. Batch N+1 is scheduled and launched while batch N’s results are still on the GPU; results are popped from result_queue one iteration later (:1499-1501), and sampling of the current batch is deferred until the previous batch’s results are in because grammar masks depend on them (:1506-1509). This is the “zero-overhead batch scheduler” from the v0.4 blog; vLLM v1 later adopted a similar async pipeline, but it’s been SGLang’s default posture for ~two years, and it’s why SGLang’s small-model decode throughput stays GPU-bound.
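The one-iteration lag between launching a batch and processing its results is the whole idea, and a toy timeline makes it visible. This sketch only models the ordering; the function name is hypothetical, batches are plain strings, and no GPU work or overlap-disable path is simulated.

```python
from collections import deque


def run_overlapped(batches):
    """Return the order of launch/process events with a one-step lag."""
    result_queue = deque()
    timeline = []
    for batch in batches:
        timeline.append(f"launch {batch}")
        result_queue.append(batch)  # stand-in for (batch.copy(), batch_result)
        if len(result_queue) > 1:
            # CPU post-processing of the previous batch happens while the
            # batch that was just launched is still running.
            timeline.append(f"process {result_queue.popleft()}")
    while result_queue:  # drain whatever is still in flight
        timeline.append(f"process {result_queue.popleft()}")
    return timeline


timeline = run_overlapped(["b0", "b1", "b2"])
print(timeline)
# ['launch b0', 'launch b1', 'process b0', 'launch b2', 'process b1', 'process b2']
```

Compare with the non-overlapped loop, where every launch would be immediately followed by processing the same batch and the GPU would idle during that CPU work.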
Batch construction, get_next_batch_to_run (:2426): merge last prefill batch into the running decode batch (:2467-2494), then prefill-first: if get_new_batch_prefill produces a batch it runs, otherwise decode (:2524-2535). Inside prefill admission (_get_new_batch_prefill_raw, :2591):
- The waiting queue is priority-sorted by cache awareness: self.policy.calc_priority(...) (:2630) with default policy lpm (longest prefix match). In managers/schedule_policy.py, every waiting request gets match_prefix run against the radix tree, then:
waiting_queue.sort(
    key=lambda r: (
        -r.num_matched_prefix_tokens
        if r.rid not in temporary_deprioritized
        else float("inf")
    )
)
— schedule_policy.py:300-306. Requests that reuse the most cache go first (maximizing hit rate and minimizing prefill compute); over 128 queued it degrades to FCFS to cap sort cost (:223-227). The temporary_deprioritized set is in-batch prefix caching (:247-293): if several queued requests share a prefix that isn’t cached yet, schedule one and hold the rest so they hit the cache the first one populates — fleet-level dedup at the scheduler.
- A PrefillAdder (schedule_policy.py:425, constructed at scheduler.py:2647-2663) admits requests under token budgets (max_prefill_tokens, remaining KV, and new_token_ratio, an adaptive estimate of decode growth), chunking oversized prompts (chunked prefill, add_chunked_req) and optionally mixing chunked prefill with decode in one forward (is_mixed_chunk).
- Under decode OOM, update_running_batch (:2864) retracts: batch.retract_decode(...) (:2892) pops the youngest requests out of the running batch back into the waiting queue (their KV freed-or-cached; they later resume by re-matching the radix tree) and tightens new_token_ratio, so admission control learns from the event. The log line to grep is "KV cache pool is full. Retract requests." (:2925-2937).
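The LPM ordering from the first bullet can be tried standalone. The sort key below is the one quoted from schedule_policy.py; the Req dataclass, the lpm_order wrapper, and the sample queue are illustrative stand-ins.

```python
from dataclasses import dataclass


@dataclass
class Req:
    """Minimal stand-in for a waiting request."""
    rid: str
    num_matched_prefix_tokens: int


def lpm_order(waiting_queue, temporary_deprioritized=frozenset()):
    """Longest-prefix-match first; held-back requests sink to the end."""
    return sorted(
        waiting_queue,
        key=lambda r: (
            -r.num_matched_prefix_tokens
            if r.rid not in temporary_deprioritized
            else float("inf")
        ),
    )


queue = [Req("a", 10), Req("b", 500), Req("c", 0), Req("d", 900)]
order = [r.rid for r in lpm_order(queue, temporary_deprioritized={"d"})]
print(order)  # ['b', 'a', 'c', 'd']
```

Request d has the longest cached prefix but is held back, which is how in-batch prefix caching lets an earlier sibling populate the cache first.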
For multi-GPU: TP ranks each run a scheduler process in lockstep; --dp-size adds a data_parallel_controller.py dispatching to replicas; PD disaggregation (srt/disaggregation/) splits prefill and decode into separate servers with KV transfer (mooncake/nixl backends) — which is exactly what the router’s PD mode fronts.
4.4 Structured output (brief)
Grammar backends live in python/sglang/srt/constrained/; default is xgrammar (server_args.py:3357-3359), alternatives llguidance and outlines. Compilation is async so it never stalls the GPU loop: requests needing a grammar park in GrammarManager.grammar_queue (constrained/grammar_manager.py:29) with a compilation-cache lookup (:111), and the scheduler pulls ready ones back into the waiting queue at the top of every prefill pass (scheduler.py:2594-2598). At decode time each request’s matcher fills a per-row bitmask which is applied to logits on-GPU:
def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
    self.matcher.fill_next_token_bitmask(vocab_mask, idx)

@staticmethod
def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
    return vocab_mask.to(device, non_blocking=True)

def apply_vocab_mask(self, logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
    if logits.device.type in {"cuda", "xpu", "musa"}:
        if _is_hip:
            apply_token_bitmask_inplace_cuda(logits, vocab_mask)
        else:
            apply_token_bitmask_inplace_triton(logits, vocab_mask)
— constrained/xgrammar_backend.py:105-117. You already know grammar bitmasks; the SGLang-specific bits are the async compile queue, the reasoner_grammar_backend.py wrapper (mask only after </think>), and the interaction with overlap scheduling (scheduler.py:1537-1546 — spec+grammar forces overlap off for that batch).
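If the bitmask step feels opaque, a pure-Python version of what the kernel does per row may help. This is a sketch under an assumption worth checking against the xgrammar docs: one bit per token packed into 32-bit words, with a set bit meaning "allowed". The function name and sample values are made up.

```python
def apply_token_bitmask(logits, bitmask_words, bits_per_word=32):
    """Set logits of disallowed tokens to -inf, in place, for one row."""
    for token_id in range(len(logits)):
        word, bit = divmod(token_id, bits_per_word)
        allowed = (bitmask_words[word] >> bit) & 1
        if not allowed:
            logits[token_id] = float("-inf")
    return logits


row_logits = [0.1, 0.2, 0.3, 0.4, 0.5]
mask = [0b01010]  # assumed layout: bits 1 and 3 set, so tokens 1 and 3 are allowed
masked = apply_token_bitmask(row_logits, mask)
print(masked)  # [-inf, 0.2, -inf, 0.4, -inf]
```

After masking, any sampler (greedy, top-p, and so on) can only pick tokens the grammar permits, because the rest carry zero probability after softmax.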
4.5 Streaming / detokenization egress
Egress is its own process so UTF-8 assembly and stop-string trimming never block scheduling. DetokenizerManager holds per-request DecodeStatus offsets (managers/detokenizer_manager.py:61-86, LRU-capped at 65536 states, :58) and does incremental decode with explicit handling of byte-incomplete codepoints:
new_text = read_texts[i][len(surr_texts[i]) :]
if recv_obj.finished_reasons[i] is None:
    # Streaming. Invariant: sent_offset >= decoded_text_len. The
    # gap (`pending`) is "printable but uncommitted" text emitted
    # in a prior "�" recovery step; we skip it from this step's
    # emission so we don't double-send.
    pending = s.sent_offset - s.decoded_text_len
    if new_text and not new_text.endswith("�"):
        # Clean text: commit to decoded_text and advance offsets.
        s.append_decoded_text(new_text)
        s.surr_offset = s.read_offset
        s.read_offset = len(s.decode_ids)
        s.sent_offset = s.decoded_text_len
        output_strs.append(new_text[pending:] if pending else new_text)
— detokenizer_manager.py:347-360. The “surrogate window” re-decodes a small token suffix each step so tokenizer merge effects stay correct without re-decoding the whole sequence. Downstream, TokenizerManager matches outputs to waiting request states and the FastAPI layer streams SSE; per-request cached_tokens (your radix-hit observability) rides along in BatchStrOutput (managers/io_struct.py:1126). Note the symmetry with the gateway: in gRPC mode this entire egress stage is reimplemented in Rust (sgl-model-gateway/src/routers/grpc/).
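The hold-back rule for byte-incomplete codepoints can be demonstrated without a tokenizer. In this minimal sketch, raw byte chunks stand in for per-step token output and the whole buffer is re-decoded each step (the real code re-decodes only a small window); the function name is hypothetical.

```python
def stream_decode(byte_chunks):
    """Emit text incrementally, holding back while the tail is incomplete."""
    buf = b""
    sent = 0  # number of characters already emitted
    emitted = []
    for chunk in byte_chunks:
        buf += chunk
        text = buf.decode("utf-8", errors="replace")
        if text.endswith("\ufffd"):
            # Tail is a partial UTF-8 sequence: wait for the next chunk.
            continue
        emitted.append(text[sent:])
        sent = len(text)
    return emitted


# "é" is the two bytes C3 A9; here they arrive in different steps.
pieces = stream_decode([b"caf", b"\xc3", b"\xa9!"])
print(pieces)  # ['caf', 'é!']
```

Without the hold-back, the second step would have streamed a replacement character to the client that could never be retracted.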
5. Suggested reading path
Router first — it’s self-contained, idiomatic Rust, and small enough to read fully:
1. sgl-model-gateway/src/policies/cache_aware.rs:1-60 (the design doc is the file header). Then select_worker (:376-505).
2. sgl-model-gateway/src/policies/tree.rs: Node/Tree (:229-250), insert (:362), prefix_match_with_counts (:531), evict_tenant_by_size (:718). Compare against the prefix tree you sketched.
3. sgl-model-gateway/src/policies/mod.rs: the policy trait + config; then skim power_of_two.rs and prefix_hash.rs as contrasting policies.
4. sgl-model-gateway/src/core/worker.rs: health hysteresis (:703-752), WorkerLoadGuard (:1109-1140); skim circuit_breaker.rs, retry.rs.
5. sgl-model-gateway/src/routers/http/router.rs: request flow end-to-end (route_typed_request, :194); then pd_router.rs:793 for pair selection.
6. python/sglang/srt/mem_cache/radix_cache.py: the real tree (whole file, 802 lines), with evict_policy.py alongside.
7. python/sglang/srt/managers/schedule_policy.py: LPM sorting + in-batch prefix caching (:133-306), then PrefillAdder (:425+).
8. python/sglang/srt/managers/scheduler.py: event_loop_overlap (:1459), get_next_batch_to_run (:2426), retraction (:2864-2937). Skim schedule_batch.py for Req lifecycle.
9. python/sglang/srt/managers/detokenizer_manager.py: egress (:260-383).
- Optional depth: constrained/grammar_manager.py, layers/attention/flashinfer_backend.py, disaggregation/.
6. Connections to your other study repos
- vllm. Same problem, different prefix-cache shape: vLLM v1 prefix caching hashes full blocks (block-granular hits) where SGLang’s radix tree gives token-granular hits and drives scheduling order (LPM) off the tree. vLLM preempts by recompute/swap; SGLang retracts with adaptive new_token_ratio (scheduler.py:2892). Overlap scheduling is default in SGLang, async-scheduling arrived later in vLLM. Reading both schedulers back-to-back is the single highest-leverage comparison.
- nano-vllm. Its ~1200 lines map 1:1 onto scheduler.py’s skeleton (event_loop_normal → step; block manager → token_to_kv_pool_allocator). Use it as the decoder ring: every nano concept exists here, buried under production features.
- dynamo. The other answer to your routing question. Dynamo’s KV-aware router consumes KV events published by workers (block stored/evicted) into a global indexer, so its view of cache state is near ground truth, plus it factors in active prefill load; cost = worker instrumentation, event bus, and an indexer in the hot path. sgl-router’s tree is inferred from routing history only: zero worker cooperation, works with any backend speaking HTTP, but suffers the staleness modes in §4.1. That’s the central trade: telemetry-driven vs inference-driven affinity. (Dynamo can also front SGLang workers; SGLang emits KV events for that integration: mem_cache/events.py and the _record_store_event/_record_remove_event calls in radix_cache.py:561,729.)
- llm-d. Kubernetes-native, vLLM-first; its inference scheduler scores pods with pluggable scorers (prefix-cache scorer, load scorer) inside the EPP. Its prefix scorer is conceptually sgl-router’s tree rebuilt as a gateway plugin. Compare its scorer-weighting approach with sgl-router’s hard threshold switch.
- gateway-api-inference-extension. Your Envoy home turf: Envoy stays the data plane, and an ext-proc “endpoint picker” (EPP) sets the upstream via header mutation; InferencePool/InferenceModel CRDs make it portable across GKE/Istio/etc. sgl-router instead is the data plane: a standalone Rust proxy owning retries, CBs, rate limiting, SSE, even tokenization, with no Envoy anywhere. Trade-offs you’ll recognize immediately: GIE composes with existing mesh/policy infra and keeps LB logic out-of-process (ext-proc RTT per request); sgl-router is one static binary with policy in-process (sub-ms decisions) but reinvents the proxy layer (look at src/core/retry.rs, src/middleware.rs and judge the reinvention quality yourself).
- xgrammar. It’s the default --grammar-backend here; constrained/xgrammar_backend.py is a thin adapter over its GrammarMatcher + token bitmask kernels. SGLang co-evolved with xgrammar (same LMSYS orbit), and the jump-forward decoding idea originated in the SGLang paper.
- flashinfer. SGLang’s premier NVIDIA attention backend (layers/attention/flashinfer_backend.py): RadixAttention’s paged, ragged batches are exactly the workload flashinfer’s plan/run API was built for. The attention_registry.py shows the full backend zoo and how prefill/decode can use different backends.
7. Tinkering on one RTX 5080 (16 GB)
Caveats first: the runtime is Linux-targeted — on your Windows box run it inside WSL2 (CUDA passthrough) or the lmsysorg/sglang Docker image. The 5080 is Blackwell sm_120: use a recent SGLang + matching CUDA 12.8+ torch wheels (the repo carries sm120-specific kernels, e.g. python/sglang/srt/layers/attention/flash_mla_sm120.py; if an attention backend complains, --attention-backend triton is the safe fallback).
Single worker sanity check (a 0.6B model leaves plenty of room in 16 GB):
python -m sglang.launch_server \
--model-path Qwen/Qwen3-0.6B \
--port 30001 --mem-fraction-static 0.35 \
--enable-metrics
curl http://127.0.0.1:30001/generate \
-d '{"text": "The capital of France is", "sampling_params": {"max_new_tokens": 32}}'
mem_fraction_static (server_args.py:434) caps weights+KV pool as a fraction of total VRAM, so two workers fit on one GPU. Watch the startup log for KV Cache is allocated and the per-prefill log line — this is your ground-truth cache observability (managers/scheduler_components/metrics_reporter.py:526-535):
Prefill batch, #new-seq: 1, #new-token: 6, #cached-token: 0, ...
Two workers + router on one GPU. Don’t use sglang_router.launch_server co-launch for this — it assumes one GPU per DP rank (bindings/python/src/sglang_router/launch_server.py:92 sets base_gpu_id = dp_id * tp_size). Launch manually:
# terminal 1 and 2 — two workers, same GPU, ~35% VRAM each
python -m sglang.launch_server --model-path Qwen/Qwen3-0.6B \
--port 30001 --mem-fraction-static 0.35 --enable-metrics
python -m sglang.launch_server --model-path Qwen/Qwen3-0.6B \
--port 30002 --mem-fraction-static 0.35 --enable-metrics
# terminal 3 — the router (pip install sglang-router, or cargo build in sgl-model-gateway/)
python -m sglang_router.launch_router \
--worker-urls http://127.0.0.1:30001 http://127.0.0.1:30002 \
--policy cache_aware \
--cache-threshold 0.3 --balance-abs-threshold 64 --balance-rel-threshold 1.5 \
--eviction-interval 120
Router listens on :30000, Prometheus metrics on :29000 (src/main.rs:270).
An experiment that makes cache-aware routing visible. Build M distinct “tenants” (long shared system prompts) and fire interleaved requests:
python -m sglang.bench_serving --backend sglang \
--host 127.0.0.1 --port 30000 \
--dataset-name generated-shared-prefix \
--gsp-num-groups 8 --gsp-prompts-per-group 32 \
--gsp-system-prompt-len 2048 --gsp-question-len 128 --gsp-output-len 64
(bench_serving.py:2027,2369-2398.) With cache_aware, each prompt group converges onto one worker, so repeats of a group’s 2048-token preamble hit that worker’s radix cache. What to watch:
- Per-worker cache hit rate: sglang:cache_hit_rate gauge on each worker’s /metrics (observability/metrics_collector.py:291-296), or eyeball #cached-token vs #new-token in the worker logs. Cache-aware should hold this high (~90%+ on repeat traffic); now rerun the bench with --policy round_robin on the router and watch it crater toward ~50% with 2 workers (each group’s prefix gets prefilled everywhere).
- Router selection counters: on :29000/metrics, smg_worker_selection_total{policy="cache_aware",...} per worker (src/observability/metrics.rs:237) shows the affinity skew; smg_worker_requests_active (:225) shows in-flight balance; smg_worker_health (:229) for the health gate.
- End-to-end TTFT: bench_serving’s mean TTFT under cache_aware vs round_robin is the headline number (cached prefill skips ~2048 of ~2176 prompt tokens per request).
- The threshold switch in action: rerun with --balance-abs-threshold 2 --balance-rel-threshold 1.1 and high concurrency. The imbalance detector now trips constantly, the router degrades to shortest-queue, selections even out, and the hit rate drops. That knob-feel (affinity vs balance as a tunable dial) is the entire thesis of the design, and you can demo it with two 0.6B workers on your one card.
To inspect the router’s mental model directly, GET http://127.0.0.1:30000/v1/loads returns each worker’s self-reported token load (src/server.rs:607-612), and the router’s debug logs (RUST_LOG=debug) print load-balancing trigger events with per-worker loads (policies/cache_aware.rs:318-326). For the cache-flush staleness mode from §4.1 (the fourth item in that list): hit one worker’s /flush_cache, keep the traffic running, and watch its #cached-token go to zero while the router keeps sending the same groups there. The approximation never notices.
FlashInfer
A study guide to the kernel library that sits at the very bottom of your inference stack.
Repo: materials/flashinfer/flashinfer (github.com/flashinfer-ai/flashinfer). All paths below
are relative to the repo root; line numbers match this checkout.
1. What it is
FlashInfer is the GPU kernel library and kernel generator that inference engines call for the
operations that dominate a serving step: attention (prefill, decode, MLA, cascade/shared-prefix,
sparse), sampling (top-k/top-p/min-p on device), plus GEMM/MoE, RoPE, and norms. It is not an
engine — it has no scheduler, no server, no model code. It exposes PyTorch-callable ops that are
JIT-compiled per configuration (dtype, head_dim, feature flags) and cached as .so files.
Who consumes it:
- SGLang: its standard attention backend on NVIDIA GPUs (--attention-backend flashinfer) and its on-GPU sampling kernels.
- vLLM: optional attention backend (VLLM_ATTENTION_BACKEND=FLASHINFER), and the channel through which NVIDIA’s trtllm-gen Blackwell kernels reach vLLM.
- Also TensorRT-LLM, TGI, MLC-LLM, LightLLM (see README.md:231-242).
The project’s own CLAUDE.md is an unusually good architecture document — read it alongside this
guide.
2. Why you care
- It defines the per-step cost model. Every latency assumption upstream — scheduler token budgets, router ITL estimates, autoscaler capacity curves — bottoms out at: decode step time ≈ (weight bytes + batch KV bytes) / HBM bandwidth; prefill time ≈ FLOPs / tensor-core throughput. Those constants are set by the kernels in this repo. When SGLang’s scheduler decides how many tokens fit in a step, it is budgeting against these kernels’ behavior.
- The plan/run split is the reason engines look the way they do. Kernels are “planned” on the CPU (workspace layout, work partitioning) once per batch composition, then “run” per layer with zero host decisions — which is what makes CUDA-graph capture, fixed capture batch sizes, and cheap per-layer dispatch possible. It is a literal control-plane/data-plane separation.
- Cascade attention is the kernel-level payoff of prefix sharing. Prefix caching as you know it deduplicates KV storage. The cascade kernels additionally deduplicate KV bandwidth — the shared prefix is read from HBM once for the whole batch instead of once per request. That’s why prefix-aware routing pays off twice.
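The cost model in the first bullet can be turned into napkin arithmetic. The sketch below is illustrative only and is not FlashInfer code: the function names and every model and hardware number (an 8B-parameter fp16 model, 32 layers, 2 TB/s of memory bandwidth, 400 TFLOP/s of sustained tensor throughput) are assumptions picked to make the two regimes visible.

```python
def decode_step_ms(weight_bytes, kv_bytes, mem_bytes_per_s):
    """Bandwidth-bound estimate: weights and resident KV are each read once per step."""
    return (weight_bytes + kv_bytes) / mem_bytes_per_s * 1e3


def prefill_ms(flops, tensor_flops_per_s):
    """Compute-bound estimate: FLOPs divided by sustained tensor-core throughput."""
    return flops / tensor_flops_per_s * 1e3


# Hypothetical model and accelerator, for illustration only.
PARAMS = 8e9                        # 8B parameters
WEIGHT_BYTES = PARAMS * 2           # fp16 weights
MEM_BW = 2e12                       # bytes/s (assumed)
TENSOR_FLOPS = 400e12               # FLOP/s (assumed sustained rate)

# batch x seq x layers x (2 KB of K+V per token per layer)
kv_bytes = 64 * 4096 * 32 * 2048
decode_ms = decode_step_ms(WEIGHT_BYTES, kv_bytes, MEM_BW)

prompt_tokens = 4096
prefill_flops = 2 * PARAMS * prompt_tokens   # ~2 FLOPs per parameter per token; attention FLOPs ignored
prompt_ms = prefill_ms(prefill_flops, TENSOR_FLOPS)

print(f"decode step ~{decode_ms:.1f} ms, prefill of {prompt_tokens} tokens ~{prompt_ms:.1f} ms")
```

Under these assumptions, the decode estimate moves with resident KV bytes and the prefill estimate moves with prompt length, which is the asymmetry the schedulers upstream budget against.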
3. Architecture map
Four layers, with a hard rule: framework-agnostic kernels never see PyTorch headers.
flashinfer/ Python package: public API, one module per op family
decode.py, prefill.py attention wrappers (the plan/run classes)
cascade.py, sampling.py shared-prefix attention; on-GPU sampling
page.py, mla/, pod.py KV-page append, DeepSeek MLA, fused prefill+decode
jit/ JIT machinery: JitSpec, Jinja codegen, ninja builds
core.py, env.py build/load/cache; ~/.cache/flashinfer layout
attention/modules.py gen_*_module() generators + URI naming
csrc/ TVM-FFI bindings + launchers (PyTorch tensor world)
batch_decode.cu Plan() and Run() entry points for decode
*.jinja type-specialization templates rendered at JIT time
include/flashinfer/ header-only CUDA kernel templates (raw pointers only)
page.cuh paged_kv_t — the page table struct kernels consume
attention/decode.cuh CUDA-core decode kernel (bandwidth-bound GEMV)
attention/prefill.cuh tensor-core prefill kernel (compute-bound GEMM)
attention/scheduler.cuh DecodePlan/PrefillPlan — CPU-side work partitioning
attention/cascade.cuh attention-state merge kernels
sampling.cuh sorting-free top-k/top-p rejection sampling
benchmarks/ per-kernel microbenchmarks (no model weights needed)
flashinfer-cubin/, flashinfer-jit-cache/ optional prebuilt-kernel packages
Bindings use TVM-FFI rather than torch extensions, so the same compiled kernels are usable from
non-PyTorch frameworks. 3rdparty/ vendors CUTLASS and spdlog.
4. Core mechanisms
4.1 The plan/run split — control plane vs data plane
Every batch attention API is a wrapper class with two phases. plan() runs on the CPU once per
batch composition (which requests, which lengths); run() executes per layer — 32-80 times per
step — with no host-side decisions. If you squint, plan is route computation and run is packet
forwarding against an installed FIB.
flashinfer/decode.py:901 (BatchDecodeWithPagedKVCacheWrapper.plan) takes the page-table
metadata as host tensors, because the planner needs to read lengths on the CPU, and calls into
C++ BatchDecodeWithPagedKVCachePlan (csrc/batch_decode.cu:39). The real work is
DecodePlan in include/flashinfer/attention/scheduler.cuh:426:
// include/flashinfer/attention/scheduler.cuh:437
FLASHINFER_CUDA_CALL(work_estimation_func(split_kv, max_grid_size, kv_chunk_size_in_pages,
new_batch_size, gdy, batch_size, indptr_h, num_qo_heads,
page_size, enable_cuda_graph, stream));
size_t padded_batch_size;
plan_info.enable_cuda_graph = enable_cuda_graph;
plan_info.split_kv = split_kv;
padded_batch_size =
(enable_cuda_graph) ? (split_kv ? max_grid_size / gdy : batch_size) : new_batch_size;
plan_info.padded_batch_size = padded_batch_size;
auto [request_indices_vec, kv_tile_indices_vec, o_indptr_vec] =
DecodeSplitKVIndptr(indptr_h, batch_size, kv_chunk_size_in_pages);
What the plan actually decides:
- Load balancing. Decode at small batch sizes can’t fill the GPU (one CTA per request per KV head). The planner binary-searches the smallest KV chunk size such that splitting long sequences into chunks produces enough CTAs to occupy every SM (PartitionPagedKVCacheBinarySearchMinNumPagePerBatch, include/flashinfer/attention/scheduler.cuh:74). Classic bin-packing, done ahead of time on the CPU.
- Workspace layout. All scratch (partial outputs for split-KV, index arrays) is carved out of one user-provided 128 MB workspace buffer by an offset allocator, with no per-call cudaMalloc.
- One H2D copy. The whole schedule is staged in pinned host memory and shipped with a single cudaMemcpyAsync (scheduler.cuh:489).
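The load-balancing search can be sketched in a few lines. This is a minimal Python illustration, assuming the search is bounded by a CTA budget (each chunk costs gdy CTAs and the total must not exceed max_grid_size); the function and parameter names are illustrative and are not the C++ signature.

```python
def ceil_div(a, b):
    return -(-a // b)


def min_pages_per_chunk(num_pages, max_grid_size, gdy, min_pages=1):
    """Smallest chunk size (in pages) whose total chunk count fits the CTA budget.

    num_pages[i] is the number of KV pages request i owns. Smaller chunks mean
    more CTAs, so the smallest chunk size that still fits maximizes parallelism.
    """
    low, high = min_pages, max(max(num_pages), min_pages)
    while low < high:
        mid = (low + high) // 2
        chunks = sum(ceil_div(n, mid) for n in num_pages)
        if chunks * gdy > max_grid_size:
            low = mid + 1          # too many CTAs: chunks must get bigger
        else:
            high = mid             # fits: try smaller chunks
    chunks = sum(ceil_div(max(n, 1), low) for n in num_pages)
    return low, chunks


# One long request and two short ones, with a budget of 16 CTAs and 2 CTAs per chunk.
pages_per_chunk, new_batch_size = min_pages_per_chunk([100, 3, 1], max_grid_size=16, gdy=2)
print(pages_per_chunk, new_batch_size)
```

In this toy input the long request is cut into six chunks while the short ones stay whole, so the "batch" the kernel sees grows from 3 to 8 work items.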
The plan is returned to Python as an opaque vector of int64 offsets (DecodePlanInfo::ToVector,
scheduler.cuh:391). At run time, csrc/batch_decode.cu:175 rehydrates raw pointers from it:
// csrc/batch_decode.cu:175
params.request_indices =
GetPtrFromBaseOffset<IdType>(int_buffer, plan_info.request_indices_offset);
params.kv_tile_indices =
GetPtrFromBaseOffset<IdType>(int_buffer, plan_info.kv_tile_indices_offset);
params.o_indptr = GetPtrFromBaseOffset<IdType>(int_buffer, plan_info.o_indptr_offset);
if (plan_info.split_kv) {
tmp_v = GetPtrFromBaseOffset<DTypeO>(float_buffer, plan_info.v_offset);
tmp_s = GetPtrFromBaseOffset<float>(float_buffer, plan_info.s_offset);
}
CUDA-graph tie-in: plan() explicitly “cannot be used in CUDA Graph or torch.compile”
(flashinfer/decode.py:1006), while run() is capturable. Under graphs the grid size must be
fixed, so the planner pads the launch to a fixed padded_batch_size and the kernel early-exits
surplus blocks via block_valid_mask (include/flashinfer/attention/decode.cuh:423). This is
why engines capture graphs at a fixed set of batch sizes and bucket requests into them.
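The bucketing that follows from fixed grid sizes can be sketched as below. The capture sizes are hypothetical, and the per-slot mask is a simplification of the per-block block_valid_mask mentioned above; neither function is a FlashInfer or engine API.

```python
import bisect

CAPTURE_SIZES = [1, 2, 4, 8, 16, 32, 64]   # hypothetical set of captured batch sizes


def pick_bucket(batch_size, capture_sizes=CAPTURE_SIZES):
    """Smallest captured batch size that can hold this batch, or None if none fits."""
    i = bisect.bisect_left(capture_sizes, batch_size)
    return capture_sizes[i] if i < len(capture_sizes) else None


def valid_mask(batch_size, padded_batch_size):
    """True for real requests, False for padding slots that should exit early."""
    return [i < batch_size for i in range(padded_batch_size)]


bucket = pick_bucket(5)          # a batch of 5 replays the graph captured for 8
mask = valid_mask(5, bucket)
print(bucket, mask)
```

A batch larger than every captured size returns None here; an engine would then fall back to eager execution or split the batch, which is a policy choice outside this sketch.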
4.2 Paged KV-cache attention: how the page table reaches the kernel
The page table is CSR. Three int32 arrays describe the whole batch:
indices (flat list of physical page ids), indptr ([batch+1], request i owns
indices[indptr[i]:indptr[i+1]]), and last_page_len (fill level of each request’s last page).
The engine (SGLang/vLLM) owns page allocation; FlashInfer only consumes this view. On the C++
side it becomes a 30-line struct passed by value into the kernel:
// include/flashinfer/page.cuh:37
template <typename DType, typename IdType>
struct paged_kv_t {
uint_fastdiv page_size;
uint32_t num_heads, head_dim, batch_size;
uint32_t stride_page, stride_n, stride_h;
DType* k_data; // base pointer of the whole page pool
DType* v_data;
IdType* indices; // physical page ids
IdType* indptr; // [batch_size + 1] per-request extent into indices
IdType* last_page_len; // [batch_size] entries used in the final page
IdType* rope_pos_offset;
...
};
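To make the CSR description concrete, here is a small Python sketch that flattens per-request page lists into the three arrays. The helper name and the example page ids are made up for illustration, and every sequence is assumed to hold at least one token.

```python
def build_page_table(page_lists, seq_lens, page_size):
    """Flatten per-request page lists into CSR arrays (indices, indptr, last_page_len)."""
    indices, indptr, last_page_len = [], [0], []
    for pages, seq_len in zip(page_lists, seq_lens):
        needed = -(-seq_len // page_size)            # ceil(seq_len / page_size)
        assert len(pages) == needed, "page count must match sequence length"
        indices.extend(pages)
        indptr.append(len(indices))
        # Tokens used in the final page; equals page_size when that page is full.
        last_page_len.append((seq_len - 1) % page_size + 1)
    return indices, indptr, last_page_len


# Request 0 holds 40 tokens in pages 7, 2, 9; request 1 holds 16 tokens in page 4.
indices, indptr, last_page_len = build_page_table([[7, 2, 9], [4]], [40, 16], page_size=16)
print(indices, indptr, last_page_len)
```

Request i's pages are indices[indptr[i]:indptr[i+1]], so the physical pages need not be contiguous or ordered, which is what lets the engine allocate them freely.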
Inside the decode kernel, threads translate logical token positions to physical addresses — a
software page-table walk. uint_fastdiv precomputes the magic numbers so the hot loop never does
an integer divide:
// include/flashinfer/attention/decode.cuh:480
uint32_t packed_page_iter_base = paged_kv.indptr[batch_idx] * paged_kv.page_size + chunk_start;
for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
uint32_t q, r;
paged_kv.page_size.divmod(packed_page_iter_base + ((j * bdz + tz) * bdy + ty) * bdx + tx, q, r);
kv_offset_smem[((j * bdz + tz) * bdy + ty) * bdx + tx] =
paged_kv.protective_get_kv_offset(q, kv_head_idx, r, 0, last_indptr);
}
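The same walk, written out in Python. This is a sketch under an assumed NHD page layout ([page_size, num_heads, head_dim] per page), which fixes the three strides; the function name is illustrative and bounds protection is omitted.

```python
def kv_element_offset(indices, indptr, page_size, strides, request, token_pos, head_idx):
    """Translate (request, logical token position, head) into an element offset in the pool."""
    stride_page, stride_n, stride_h = strides
    # Position in the request's packed page list, then split into (page slot, entry in page).
    page_iter, entry = divmod(indptr[request] * page_size + token_pos, page_size)
    physical_page = indices[page_iter]
    return physical_page * stride_page + entry * stride_n + head_idx * stride_h


page_size, num_heads, head_dim = 16, 4, 128
strides = (page_size * num_heads * head_dim, num_heads * head_dim, head_dim)  # NHD (assumed)
indices, indptr = [7, 2, 9, 4], [0, 3, 4]

# Token 20 of request 0 lives in its second page (physical page 2), entry 4.
off_a = kv_element_offset(indices, indptr, page_size, strides, request=0, token_pos=20, head_idx=1)
# Token 0 of request 1 lives at the start of physical page 4.
off_b = kv_element_offset(indices, indptr, page_size, strides, request=1, token_pos=0, head_idx=0)
print(off_a, off_b)
```

The kernel does this for many positions per iteration, which is why replacing the divide with precomputed multiply-and-shift constants matters there and not here.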
The decode main loop (decode.cuh:523-585) is a textbook bandwidth machine: a multi-stage
cp_async pipeline streams K tiles and V tiles from HBM into shared memory while the previous
tile is being dotted against the single query vector held in registers:
// include/flashinfer/attention/decode.cuh:523 (elided)
for (uint32_t iter = 0; iter < ceil_div(chunk_size, tile_size_per_bdx * bdy * bdz); ++iter) {
cp_async::wait_group<2 * num_stages_smem - 1>(); // wait for oldest K tile
compute_qk<...>(params, variant, batch_idx, k_smem + ..., q_vec, ...); // q . k
cp_async::pred_load<...>(k_smem + ..., paged_kv.k_data + kv_offset[j], ...); // prefetch next K
cp_async::commit_group();
cp_async::wait_group<2 * num_stages_smem - 1>(); // wait for V tile
update_local_state<...>(v_smem + ..., s, stage_idx, st, tx); // softmax-weighted V accumulate
cp_async::pred_load<...>(v_smem + ..., paged_kv.v_data + kv_offset[j], ...); // prefetch next V
cp_async::commit_group();
}
Every KV byte is read exactly once; the arithmetic per byte is a handful of FMAs. That is the
whole “decode is bandwidth-bound” story in code: the kernel’s job is to keep the memory pipes
full, and the planner’s job (4.1) is to make sure there are enough CTAs for that. When a sequence
was split, partial results carry (output, log-sum-exp) and a second kernel merges them — the
same merge operator cascade attention uses (4.4).
4.3 Prefill vs decode kernel families
Two separate kernel templates exist because the two phases sit on opposite ends of the roofline:
- Decode (include/flashinfer/attention/decode.cuh): one query token per request. q·K is a GEMV, O(1) FLOPs per KV byte, so it runs on plain CUDA cores with vectorized loads, as above. Tensor cores would sit idle.
- Prefill (include/flashinfer/attention/prefill.cuh): many query tokens attend to many keys. q·K is a GEMM, O(qo_len) FLOPs per KV byte, so the kernel is built around 16x16x16 tensor-core fragments (mma = matrix-multiply-accumulate instruction):
// include/flashinfer/attention/prefill.cuh:946
for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
if constexpr (std::is_same_v<typename KTraits::DTypeQKAccum, float>) {
if (mma_d == 0) {
mma::mma_sync_m16n16k16_row_col_f16f16f32<typename KTraits::DTypeQ, MMAMode::kInit>(
s_frag[mma_q][mma_kv], a_frag[mma_q], b_frag);
} else {
mma::mma_sync_m16n16k16_row_col_f16f16f32<typename KTraits::DTypeQ>(
s_frag[mma_q][mma_kv], a_frag[mma_q], b_frag);
}
}
...
}
The GQA twist that blurs the line: with grouped-query attention, 4-8 query heads share each KV
head, so even a “decode” step has a small matrix (group_size x head_dim) per KV head. Pass
use_tensor_cores=True to the decode wrapper and it silently routes through the prefill
kernel family (flashinfer/decode.py:1181-1213 calls get_batch_prefill_module) — same page
table, different execution engine. Mixed chunked-prefill+decode batches get dedicated paths:
POD-attention (csrc/batch_pod.cu) fuses both phases into one kernel, and the newer holistic
persistent scheduler lives in csrc/batch_attention.cu /
include/flashinfer/attention/persistent.cuh.
For your mental model: a serving step’s cost = one compute-bound prefill kernel over the new tokens + one bandwidth-bound decode kernel over all resident KV. Chunked prefill exists upstream precisely to interleave these two regimes without starving either resource.
4.4 Cascade attention: the kernel payoff of prefix sharing
FlashInfer generalizes attention into an attention state (v(I), s(I)) — the weighted value
and log-sum-exp over an index set I — with an associative, commutative merge operator
(docs/tutorials/recursive_attention.rst). Attention over I ∪ J equals merging the states of
I and J, computed in any order, on any device. The merge kernel is just a numerically-stable
softmax renormalization:
// include/flashinfer/attention/cascade.cuh:98 (MergeStateInPlaceKernel)
float s_val = s[pos * num_heads + head_idx];
float s_other_val = s_other[pos * num_heads + head_idx];
float s_max = max(s_val, s_other_val);
s_val = math::ptx_exp2(s_val - s_max);
s_other_val = math::ptx_exp2(s_other_val - s_max);
float scale = s_val / (s_val + s_other_val);
float other_scale = s_other_val / (s_val + s_other_val);
for (uint32_t i = 0; i < vec_size; ++i) {
v_vec[i] = scale * v_vec[i] + other_scale * v_other_vec[i];
}
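A pure-Python check of the claim that merging states equals attention over the union. The sketch keeps the log-sum-exp in base 2 to match the exp2 calls above; the returned merged log-sum-exp and all function names are additions for illustration rather than a transcription of the kernel.

```python
import math


def attention_state(q, keys, values):
    """Return (v, s): softmax-weighted value and base-2 log-sum-exp over one index set."""
    logits = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(logits)
    weights = [math.exp(x - m) for x in logits]
    total = sum(weights)
    dim = len(values[0])
    v = [sum(w * val[d] for w, val in zip(weights, values)) / total for d in range(dim)]
    s = (m + math.log(total)) / math.log(2.0)        # log2 of sum(exp(logit))
    return v, s


def merge_state(v_a, s_a, v_b, s_b):
    """Merge two attention states with a numerically stable renormalization."""
    s_max = max(s_a, s_b)
    w_a = 2.0 ** (s_a - s_max)
    w_b = 2.0 ** (s_b - s_max)
    scale_a = w_a / (w_a + w_b)
    scale_b = w_b / (w_a + w_b)
    v = [scale_a * x + scale_b * y for x, y in zip(v_a, v_b)]
    return v, s_max + math.log2(w_a + w_b)


q = [1.0, 0.5]
keys = [[0.2, 0.1], [1.0, -0.3], [0.5, 0.5], [-0.7, 0.9]]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0]]

full_v, full_s = attention_state(q, keys, values)
v_i, s_i = attention_state(q, keys[:2], values[:2])      # index set I
v_j, s_j = attention_state(q, keys[2:], values[2:])      # index set J
merged_v, merged_s = merge_state(v_i, s_i, v_j, s_j)
print(full_v, merged_v)
```

Because the merge is associative and commutative, the two halves could equally be computed by different CTAs (split-KV) or different passes (cascade) and combined afterwards.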
MultiLevelCascadeAttentionWrapper (flashinfer/cascade.py:226) exploits this for shared
prefixes: level 0 treats the entire batch as one query block attending to the shared prefix’s
KV; the remaining levels handle each request’s unique suffix; then the per-level states are merged:
# flashinfer/cascade.py:547
out, lse = self._batch_prefill_wrappers[-1].run(q, paged_kv_cache, return_lse=True)
for wrapper in self._batch_prefill_wrappers[:-1]:
out_i, lse_i = wrapper.run(q, paged_kv_cache, return_lse=True)
merge_state_in_place(out, lse, out_i, lse_i)
return out
Why it’s fast: in the level-0 pass, all batch queries attend to the prefix together (it’s matrix-shaped, tensor-core friendly), and the prefix KV is fetched from HBM once for the whole batch instead of once per request. SGLang’s radix tree shares prefix pages (capacity win); cascade shares prefix traffic (bandwidth win). The same split-then-merge machinery also powers split-KV decode (4.2) and, conceptually, any cross-device attention sharding.
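The bandwidth win can be put in numbers with an idealized traffic count. This sketch ignores caches and merge overhead, and the batch shape and bytes-per-token figure are hypothetical; it only counts KV bytes streamed per attention call.

```python
def kv_bytes_read(batch, prefix_tokens, suffix_tokens, bytes_per_token, cascade):
    """KV bytes streamed for one attention call over a batch that shares a prefix."""
    prefix_reads = 1 if cascade else batch       # cascade reads the shared prefix once
    return (prefix_reads * prefix_tokens + batch * suffix_tokens) * bytes_per_token


BYTES_PER_TOKEN = 2048   # per layer; assumed config: K+V x 4 KV heads x 128 dim x fp16

flat = kv_bytes_read(64, 4096, 256, BYTES_PER_TOKEN, cascade=False)
shared = kv_bytes_read(64, 4096, 256, BYTES_PER_TOKEN, cascade=True)
print(f"{flat / 1e6:.0f} MB without cascade, {shared / 1e6:.0f} MB with, {flat / shared:.1f}x less traffic")
```

The ratio grows with both batch size and the share of the context that is common, which is the regime a prefix-aware router is trying to create.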
4.5 JIT compilation: kernels generated and cached at runtime
The kernel space is combinatorial — dtype(q) x dtype(kv) x dtype(o) x head_dim x position encoding x sliding window x logits soft-cap x backend x GPU arch. AOT-compiling all of it is infeasible, so FlashInfer compiles on first use. Each combination gets a deterministic URI:
# flashinfer/jit/attention/modules.py:78
return (
f"batch_decode_with_kv_cache_dtype_q_{filename_safe_dtype_map[dtype_q]}_"
f"dtype_kv_{filename_safe_dtype_map[dtype_kv]}_"
f"dtype_o_{filename_safe_dtype_map[dtype_o]}_"
f"dtype_idx_{filename_safe_dtype_map[dtype_idx]}_"
f"head_dim_qk_{head_dim_qk}_"
f"head_dim_vo_{head_dim_vo}_"
f"posenc_{pos_encoding_mode}_"
f"use_swa_{use_sliding_window}_"
f"use_logits_cap_{use_logits_soft_cap}"
)
A gen_*_module() function (e.g. gen_batch_decode_module,
flashinfer/jit/attention/modules.py:915) renders a Jinja template
(csrc/batch_decode_customize_config.jinja) into a type-specialized config header, copies the
.cu sources, and returns a JitSpec. Building means: emit build.ninja, run nvcc, load the
.so through TVM-FFI — with an AOT short-circuit:
# flashinfer/jit/core.py:307
def build_and_load(self):
if self.is_aot:
return self.load(self.aot_path)
# Guard both build and load with the same lock to avoid race condition
with FileLock(self.lock_path, thread_local=False):
so_path = self.jit_library_path
verbose = os.environ.get("FLASHINFER_JIT_VERBOSE", "0") == "1"
self.build(verbose, need_lock=False)
result = self.load(so_path)
return result
Caching is two-level: @functools.cache in-process (see get_sampling_module,
flashinfer/sampling.py:66) and on disk at ~/.cache/flashinfer/<version>/<archs>/cached_ops/
(flashinfer/jit/env.py:135-150 — version and arch list are part of the path, so upgrades and
new GPUs invalidate cleanly). Recompilation triggers: a never-seen parameter combination, source
file hash change, compiler flag change, arch list change, version change.
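The two cache levels can be mimicked with a toy model. Everything here is hypothetical: the shortened URI format, the directory layout below cached_ops, and the BUILD_LOG stand-in for a real ninja/nvcc build are assumptions for illustration, not FlashInfer's API.

```python
import functools
from pathlib import Path

BUILD_LOG = []   # records each simulated build so cache hits are visible


def module_uri(dtype_q, dtype_kv, head_dim, use_swa):
    """Deterministic name: one URI per kernel configuration."""
    return f"batch_decode_dtype_q_{dtype_q}_dtype_kv_{dtype_kv}_head_dim_{head_dim}_use_swa_{use_swa}"


def cache_path(root, version, archs, uri):
    """Version and arch list are path components, so changing either misses the cache."""
    return Path(root) / version / "_".join(archs) / "cached_ops" / uri


@functools.cache
def get_module(dtype_q, dtype_kv, head_dim, use_swa):
    """In-process level: identical arguments never reach the build step twice."""
    uri = module_uri(dtype_q, dtype_kv, head_dim, use_swa)
    BUILD_LOG.append(uri)    # stand-in for "look on disk, otherwise compile"
    return uri


get_module("f16", "f16", 128, False)
get_module("f16", "f16", 128, False)        # in-process hit, no second build
get_module("f16", "f8_e4m3", 128, False)    # never-seen combination, builds again
print(len(BUILD_LOG), cache_path("/tmp/cache", "0.0.0", ["12.0f"], BUILD_LOG[0]))
```

The point of the deterministic URI is that the disk level survives process restarts: a fresh process recomputes the same name and finds the artifact already there.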
Ops note for the layers you run: the first request with a new model shape stalls for seconds to
minutes of nvcc. This is why engines warm up before serving traffic, why the flashinfer-jit-cache
and flashinfer-cubin wheels exist (prebuilt .so files / downloaded trtllm-gen cubins), and why
FLASHINFER_DISABLE_JIT=1 (flashinfer/jit/core.py:290) is used in locked-down prod images to
fail fast instead of compiling.
4.6 Sampling: selection never leaves the GPU
After attention produces logits, token selection (temperature, top-k, top-p, min-p) runs in a single fused kernel per batch — no sort of the 128k-entry vocab, no host round-trip. The trick is rejection sampling with a shrinking pivot window: sample a token from the full distribution, then check in one block-wide reduction whether the top-p constraint would have admitted it; if not, tighten the admissible probability range and resample. One CTA per batch row; the do/while converges in a few rounds:
// include/flashinfer/sampling.cuh:1070 (TopPSamplingFromProbKernel, inside do { ... })
if (aggregate_gt_pivot_0 < top_p) {
// case 1: pivot_0 accepted
break;
}
if (aggregate_gt_pivot_1 < top_p) {
// case 2: pivot_0 rejected, pivot_1 accepted
low = pivot_0;
high = pivot_1;
q = aggregate_gt_pivot_0;
} else {
// case 3: pivot_0 rejected, pivot_1 rejected
low = pivot_1;
q = aggregate_gt_pivot_1;
}
} while (low < high);
(pivot_0 is the sampled token’s probability, pivot_1 a midpoint — two candidate thresholds
evaluated per pass over the vocab, so each round halves the search interval.) Python entry points:
top_p_sampling_from_probs (flashinfer/sampling.py:976), top_k_top_p_sampling_from_probs
(flashinfer/sampling.py:1579), plus chain_speculative_sampling for draft-token verification in
speculative decoding. Bindings live in csrc/sampling.cu / csrc/renorm.cu. The end-to-end
consequence you already know from the serving side: the only thing crossing PCIe per step is a
handful of int32 token ids.
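The core idea, sampling first and checking the constraint afterwards, fits in a few lines of Python. This is a simplified single-pivot variant written for clarity, not the dual-pivot kernel quoted above, and it assumes 0 < top_p <= 1 and probabilities that sum to 1.

```python
import random


def top_p_rejection_sample(probs, top_p, rng):
    """Draw one token id from the top-p nucleus without sorting the vocabulary."""
    pivot = 0.0                                    # tokens with prob <= pivot are excluded
    while True:
        weights = [p if p > pivot else 0.0 for p in probs]
        token = rng.choices(range(len(probs)), weights=weights, k=1)[0]
        # Total mass of tokens strictly more probable than the sampled one.
        mass_above = sum(p for p in probs if p > probs[token])
        if mass_above < top_p:
            return token                           # inside the nucleus: accept
        pivot = probs[token]                       # reject it and everything at or below it


rng = random.Random(0)
probs = [0.5, 0.3, 0.15, 0.05]
draws = [top_p_rejection_sample(probs, 0.7, rng) for _ in range(2000)]
print(sorted(set(draws)))
```

A token belongs to the nucleus exactly when the mass above it is below top_p, so each round needs only sums over the vocabulary, which on the GPU become block-wide reductions instead of a sort.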
5. Suggested reading path
- README.md + docs/tutorials/kv_layout.rst for vocabulary: NHD/HND layouts, ragged vs paged, indptr conventions. Then docs/tutorials/recursive_attention.rst (10 minutes, the merge operator).
- flashinfer/decode.py:712-900 (BatchDecodeWithPagedKVCacheWrapper.__init__): workspace buffers, CUDA-graph buffer pinning. This class is the contract engines program against.
- flashinfer/decode.py:901-1281 (plan) and :1352-1660 (run). Note what is host metadata vs device tensors, and what gets cached on the wrapper.
- csrc/batch_decode.cu (200 lines, read in full). Python tensors become paged_kv_t + raw pointers; plan info becomes workspace offsets.
- include/flashinfer/attention/scheduler.cuh:60-130 and :366-492, the planner. This is the most “your world” file in the repo: occupancy estimation, binary-search load balancing, offset allocation.
- include/flashinfer/page.cuh:37-210, the page table struct and address translation.
- include/flashinfer/attention/decode.cuh:380-608, one full bandwidth-bound kernel. Skim for the pipeline structure, not the index arithmetic.
- include/flashinfer/attention/prefill.cuh. Skim compute_qk (:878) just to see the mma fragments; then flashinfer/cascade.py:226-560 + include/flashinfer/attention/cascade.cuh.
- flashinfer/sampling.py + include/flashinfer/sampling.cuh:968-1090.
- flashinfer/jit/core.py + flashinfer/jit/attention/modules.py:915-960, then look inside ~/.cache/flashinfer/ after running anything.
- benchmarks/bench_batch_decode.py, 100 lines that tie the whole cost model together.
End-to-end trace to cement it: one wrapper.run(q, kv_cache) decode call —
decode.py:1352 → csrc/batch_decode.cu:83 → BatchDecodeWithPagedKVCacheDispatched
(include/flashinfer/attention/decode.cuh:741) → BatchDecodeWithPagedKVCacheKernel
(decode.cuh:613) → optional VariableLengthMergeStates (cascade.cuh:687).
6. Connections to your other study repos
- sglang: primary consumer. python/sglang/srt/layers/attention/flashinfer_backend.py wraps these exact wrapper classes; its init_forward_metadata is where plan() gets called each scheduler step, and SGLang’s “capture batch sizes” map to the CUDA-graph padding you saw in 4.1. Its sampler calls flashinfer.sampling.*. SGLang’s radix cache decides what pages are shared; FlashInfer decides how cheaply they’re read.
- vllm: optional attention backend (VLLM_ATTENTION_BACKEND=FLASHINFER); on Blackwell it is also the delivery vehicle for trtllm-gen attention cubins. Comparing vLLM’s FlashInfer backend wrapper with SGLang’s is a good exercise: same plan/run contract, different metadata plumbing.
- nano-vllm: uses flash-attn instead of FlashInfer; reading its attention call next to SGLang’s FlashInfer backend shows exactly what the plan/run + page-table contract buys (and costs) versus a single monolithic kernel call.
- dynamo / llm-d / gateway-api-inference-extension: never touch these kernels, but every latency model they embed (TTFT vs ITL tradeoffs, KV-cache-aware routing, saturation detection) is a function of the two regimes in 4.3. When a router predicts “adding one decode request costs ~X ms/step,” X is the marginal KV bytes divided by bandwidth, i.e. section 4.2.
- xgrammar: its apply_token_bitmask_inplace CUDA kernel is a tiny cousin of the sampling kernels here: same “mutate logits on device, never sync to host” principle, applied just before FlashInfer’s sampling kernel draws the token.
7. Hands-on with your RTX 5080 (16 GB)
Platform honesty: WSL2, not native Windows. There are no Windows wheels, CI is Linux, the JIT
pipeline (ninja + nvcc + .so + TVM-FFI load_module) targets Linux shared objects; the stray
sys.platform == "win32" branch in flashinfer/autotuner.py:90 is not a support statement. Use
WSL2 Ubuntu + CUDA toolkit 12.8+ (SM 12.0 needs it). Clone inside the WSL filesystem (~/), not
under /mnt/c/ — JIT compilation hammers the filesystem.
Your GPU is supported. RTX 5080 = Blackwell SM 12.0, listed in README.md:73. The fa2
decode/prefill paths, cascade, and sampling all work; Hopper-only fa3 (sm90a) and B200-class MLA
(sm100a) kernels won’t be available — flashinfer.utils.is_sm120a_supported is the gate tests
use. Setup:
# inside WSL2 Ubuntu with CUDA 12.8+ toolkit
git clone https://github.com/flashinfer-ai/flashinfer.git --recursive && cd flashinfer
pip install --no-build-isolation -e . -v
export FLASHINFER_CUDA_ARCH_LIST="12.0f" # skip detecting/compiling other archs
flashinfer show-config
Microbenchmark 1 — one attention call (and one JIT compile). Run the README quickstart
(README.md:123-133): flashinfer.single_decode_with_kv_cache(q, k, v) with a 2048-token KV.
First call stalls while nvcc builds the kernel — watch ~/.cache/flashinfer/ populate
(generated/ sources, cached_ops/*.so). Second call is microseconds. That stall is the JIT
story of 4.5 made visceral.
Microbenchmark 2 — the decode bandwidth sweep. benchmarks/bench_batch_decode.py already
measures the exact thing you study:
# benchmarks/bench_batch_decode.py:70-78
measurements = bench_gpu_time(lambda: wrapper.run(q, kv_data))
ms = np.median(measurements)
io = q.numel() * q.element_size() + kv_data.numel() * kv_data.element_size()
print(f"execution time: {ms}ms")
print(f"memory bandwidth: {io / ms / 1024 / 1024:.2f} GB/s")
Trim the sweep loops at the bottom (:81-95) for 16 GB — the largest configs allocate
batch x seq x 2048 bytes of KV (batch 512 x 16384 tokens ≈ 17 GB, OOM; cap at batch ≤ 128 or
seq ≤ 8192). What you should observe: at batch 1 the achieved GB/s is far below spec (not enough
CTAs — this is what split-KV planning fixes); as batch x seq grows, GB/s climbs and plateaus near
the 5080’s ~960 GB/s GDDR7 ceiling while latency becomes a clean linear function of total KV
bytes. That plateau is the per-token decode cost model your schedulers assume. Bonus run:
switch kv_dtype to torch.float8_e4m3fn and watch step time roughly halve: quantized KV as a
bandwidth lever, with zero scheduler changes upstream.
Sanity math for the plateau: per token per layer, KV bytes = 2 (one K, one V) x num_kv_heads x head_dim x dtype_size = 2 x 4 x 128 x 2 B = 2 KB in this benchmark’s config; batch 64 x seq 4096 → ~537 MB per call → ~0.56 ms at 960 GB/s. Your measurement should land within ~20% of that.
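The same napkin math as a script, so the prediction can be rerun for other shapes before touching the GPU. The helper name is illustrative; the shape and the 960 GB/s ceiling are the figures quoted above, and the fp8 line assumes a one-byte KV dtype.

```python
def kv_bytes_per_token(num_kv_heads, head_dim, dtype_size):
    """Bytes of K plus V for one token in one layer."""
    return 2 * num_kv_heads * head_dim * dtype_size


BANDWIDTH = 960e9                                  # bytes/s, the ceiling quoted above

per_token = kv_bytes_per_token(num_kv_heads=4, head_dim=128, dtype_size=2)
total_bytes = 64 * 4096 * per_token                # batch x seq x bytes per token
predicted_ms = total_bytes / BANDWIDTH * 1e3

fp8_bytes = 64 * 4096 * kv_bytes_per_token(4, 128, dtype_size=1)
fp8_ms = fp8_bytes / BANDWIDTH * 1e3

largest_config = 512 * 16384 * per_token           # the sweep entry that exceeds 16 GB
print(f"{per_token} B/token, {total_bytes / 1e6:.0f} MB, {predicted_ms:.2f} ms, fp8 {fp8_ms:.2f} ms")
```

Comparing the printed prediction with the benchmark's measured median is the quickest way to see how close the kernel gets to the bandwidth ceiling on a given shape.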
If you can’t run it locally: the next-best path is (1) the FlashInfer paper, arXiv
2501.01005; (2) the two blog posts — flashinfer.ai/2024/02/02/cascade-inference.html and
flashinfer.ai/2025/03/10/sampling.html — which are the prose versions of 4.4 and 4.6; (3)
docs/tutorials/kv_layout.rst and recursive_attention.rst in-repo; (4) read
benchmarks/bench_batch_decode.py and predict its output from the napkin math above — the
prediction exercise teaches the cost model even without the GPU.
XGrammar
A study guide to the XGrammar repo checked out at materials/xgrammar/xgrammar (MLC’s structured-generation engine, v0.2.2 — i.e. “XGrammar-2”). Written for someone who already understands the constrained-decoding loop and wants to know where each piece lives in the code.
1. What it is and where it sits
XGrammar is the CPU-side brain of structured output for most serving stacks. It takes a structure spec (JSON Schema, regex, EBNF, or “structural tags” for tool calling), compiles it against a specific tokenizer vocabulary, and then — per request, per decode step — answers one question: which token IDs are legal next? The answer comes back as a packed bitmask that the engine applies to logits on the GPU before sampling.
It is the default structured-output backend in vLLM, SGLang, TensorRT-LLM, and MLC-LLM. The library itself is a C++17 core (cpp/) with thin Python bindings (python/xgrammar/, bound via apache-tvm-ffi), plus a handful of GPU kernels for the one step that touches the GPU. Everything else — compilation, state tracking, mask filling — runs on the host CPU.
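A packed bitmask and its application to logits can be sketched in plain Python. The layout assumed here (32 tokens per word, token i at bit i % 32 of word i // 32, 1 meaning allowed) and both function names are illustrative; the library's own kernels are the ones listed under python/xgrammar/kernels/.

```python
def pack_bitmask(allowed_ids, vocab_size):
    """Pack a set of legal token ids into 32-bit words, one bit per token."""
    words = [0] * (-(-vocab_size // 32))            # ceil(vocab_size / 32) words
    for token_id in allowed_ids:
        words[token_id // 32] |= 1 << (token_id % 32)
    return words


def apply_bitmask(logits, words):
    """Return logits with every disallowed token forced to -inf."""
    return [
        x if (words[i // 32] >> (i % 32)) & 1 else float("-inf")
        for i, x in enumerate(logits)
    ]


vocab_size = 70
words = pack_bitmask({0, 33, 69}, vocab_size)
masked = apply_bitmask([0.0] * vocab_size, words)
legal = [i for i, x in enumerate(masked) if x != float("-inf")]
print(words, legal)
```

Packing is what keeps the per-step host-to-device transfer small: a 128k-token vocabulary needs 4096 words per request instead of 128k booleans.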
One historical note that matters for reading the code: XGrammar v1 (the MLSys ’25 paper) used a persistent-stack pushdown automaton. XGrammar-2 (this code, the 2026 paper) replaced the PDA with an Earley parser over per-rule FSMs. Several docstrings still say “pushdown automaton” (include/xgrammar/matcher.h:43, python/xgrammar/matcher.py:205) — the actual engine is cpp/earley_parser.cc. Your PDA mental model still works: the Earley state sets are the nondeterministic stack configurations, with rule_start_pos playing the role of the stack back-pointer.
2. Why you care
You already derived the loop. Here is the same loop, named with this codebase’s vocabulary:
| Your concept | XGrammar name | Where |
|---|---|---|
| CPU-side automaton holding parse state | GrammarMatcher wrapping an EarleyParser | cpp/grammar_matcher.cc, cpp/earley_parser.cc |
| Per-step token bitmask emitted by CPU | FillNextTokenBitmask | cpp/grammar_matcher.cc:809 |
| GPU applies mask to logits before sampling | apply_token_bitmask_inplace (Triton/CUDA/CPU) | python/xgrammar/kernels/ |
| Sampled token returns to advance the automaton | AcceptToken | cpp/grammar_matcher.cc:610 |
| Context-independent vs context-dependent token split | AdaptiveTokenMask: accepted/rejected (precomputed) vs uncertain (runtime) | cpp/compiled_grammar_impl.h:40 |
| Precompute-vs-on-the-fly tradeoff | GrammarCompiler + per-ParserState mask cache | cpp/grammar_compiler.cc:1048 |
| Tokens straddling grammar boundaries | the “uncertain” class + lookahead assertions | cpp/grammar_compiler.cc:641 |
| Jump-ahead / forced tokens | FindJumpForwardString | cpp/grammar_matcher.cc:964 |
| Rollback for speculative decoding | Rollback, TraverseDraftTree | cpp/grammar_matcher.cc:1023, :55 |
The interesting engineering is all in two places: how much of the mask can be precomputed per automaton state (compile time), and how cheaply the rest can be resolved per step (runtime). Everything below maps those two.
3. Architecture map
xgrammar/
├── include/xgrammar/ # public C++ API (5 headers, including grammar, compiler, matcher, tokenizer_info)
├── cpp/ # the core
│ ├── grammar_parser.cc # EBNF text -> grammar IR
│ ├── json_schema_converter.cc # JSON Schema -> EBNF string (3200 lines; the biggest file)
│ ├── regex_converter.cc # regex -> EBNF
│ ├── structural_tag.cc # tool-call "structural tags" -> grammar (XGrammar-2 agentic feature)
│ ├── grammar_functor.cc # grammar passes: normalize, inline, dead-rule elim, build per-rule FSMs (GrammarOptimizer at grammar_functor.h:409)
│ ├── fsm.cc / fsm_builder.cc # FSM data structure; regex->FSM and trie->FSM builders
│ ├── grammar_compiler.cc # adaptive token-mask precompute + grammar/rule caches
│ ├── compiled_grammar_impl.h # CompiledGrammar = grammar + tokenizer + mask cache; AdaptiveTokenMask struct
│ ├── earley_parser.{h,cc} # the runtime automaton (scan/predict/complete)
│ ├── grammar_matcher.cc # GrammarMatcher: AcceptToken, FillNextTokenBitmask, jump-forward, rollback
│ ├── tokenizer_info.cc # vocab decoding (byte-fallback/byte-level), sorted vocab, prefix-trie ranges
│ └── support/ # dynamic_bitset.h (the mask), thread_safe_cache.h (LRU), thread_pool.h
├── python/xgrammar/ # bindings mirror the C++ classes 1:1
│ ├── grammar.py, compiler.py, matcher.py, tokenizer_info.py
│ ├── kernels/ # apply_token_bitmask_inplace_{triton,cuda,cpu,torch,torch_compile,mlx}
│ └── testing.py # debug helpers (_get_masked_tokens_from_bitmask, etc.)
├── docs/tutorials/ # engine_integration.md is the integration contract
├── examples/benchmark/ # bench_grammar_compile_mask_gen.py — compile + mask-fill latency
└── tests/python/ # extensive; test_grammar_matcher_*.py show the API in anger
4. Core mechanisms
4.1 Grammar compilation: JSON Schema → EBNF → IR → FSMs
Every input format funnels into EBNF text, then into a grammar IR, then per-rule FSMs:
// cpp/grammar.cc:43
Grammar Grammar::FromJSONSchema(
const std::string& schema, bool any_whitespace, std::optional<int> indent,
std::optional<std::pair<std::string, std::string>> separators, bool strict_mode,
std::optional<int> max_whitespace_cnt, bool print_converted_ebnf
) {
auto ebnf_string =
JSONSchemaToEBNF(schema, any_whitespace, indent, separators, strict_mode, max_whitespace_cnt);
...
return FromEBNF(ebnf_string);
}
JSONSchemaToEBNF (cpp/json_schema_converter.cc:3078) walks the schema and emits grammar rules. The IR node types are in cpp/grammar_impl.h:108 — kByteString, kCharacterClass(Star), kRuleRef, kSequence, kChoices, kTagDispatch, kRepeat, plus the XGrammar-2 token-level edges kToken/kExcludeToken (grammar positions that match whole token IDs, not bytes — used for tool-call dispatch). GrammarOptimizer::Apply (cpp/grammar_functor.h:409) then normalizes, inlines, and builds one FSM per rule (per_rule_fsms).
Where the compilation cost lives. MultiThreadCompileGrammar (cpp/grammar_compiler.cc:1048) enumerates every reachable, scanable FSM state of every rule and precomputes a token mask for each, fanned out over a thread pool:
// cpp/grammar_compiler.cc:1109
for (int32_t rule_id = 0; rule_id < ...NumRules(); ++rule_id) {
auto rule = compiled_grammar_impl->grammar->GetRule(rule_id);
const auto& rule_fsm = compiled_grammar_impl->grammar->per_rule_fsms[rule_id];
auto cur_stack_element =
ParserState(rule_id, rule.body_expr_id, 0, ParserState::kNoPrevInputPos, 0);
std::unordered_set<int> reachable_states;
rule_fsm->GetFsm().GetReachableStates(&reachable_states);
for (int i : reachable_states) {
cur_stack_element.element_id = i;
if (!rule_fsm->GetFsm().IsScanableState(i)) continue;
add_task_adaptive_token_mask(cur_stack_element, rule_id == root_rule_id);
}
}
Each task walks a big chunk of the ~100k-token vocabulary against the parser — this is the O(states × vocab) precompute you predicted. It is amortized by two caches:
- Grammar-level LRU cache: ThreadSafeLRUCache<UnionKey, CompiledGrammar, ...> grammar_level_cache_ (cpp/grammar_compiler.cc:1400, impl in cpp/support/thread_safe_cache.h), keyed by the schema/EBNF/regex string plus options (cpp/grammar_compiler.cc:1452). Same schema twice → free.
- Rule-level “crossing cache” (XGrammar-2): per-rule FSMs are structurally hashed (GrammarFSMHasher, cpp/grammar_functor.h:426) and masks are reused across different grammars that share sub-rules (cpp/grammar_compiler.cc:794-833). Two different tool schemas that both contain a JSON string rule share that rule’s mask work.
CompiledGrammar is serializable (include/xgrammar/compiler.h:39) so engines can persist compilation across processes.
4.2 The adaptive token mask cache — your context-independent/dependent split
This is the data structure to remember. For each automaton state, the vocab is partitioned exactly the way you derived it:
// cpp/compiled_grammar_impl.h:40
struct AdaptiveTokenMask {
enum class StoreType {
// Only store all accepted token indices ... useful when |accepted| < |rejected|.
kAccepted = 0,
// Only store all rejected token indices ... useful when |accepted| > |rejected|.
kRejected = 1,
// Store all accepted token indices in a bitset, when both are large.
kAcceptedBitset = 2
};
StoreType store_type;
static constexpr int USE_BITSET_THRESHOLD = 1000;
std::vector<int32_t> accepted_indices;
std::vector<int32_t> rejected_indices;
DynamicBitset accepted_bitset;
std::vector<int32_t> uncertain_indices;
};
The header comment (cpp/compiled_grammar_impl.h:28-39) states it directly: accepted = decidable by this state alone, rejected = ditto, uncertain = “need the state of the parent ParserStates to determine if acceptable.” Accepted+rejected are your context-independent tokens; uncertain are context-dependent. StoreType is the “adaptive” part: store whichever side is smaller, and fall back to a bitset when both sides are large (USE_BITSET_THRESHOLD = 1000).
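One plausible reading of the StoreType comments, written as a selection function. This is an assumption-laden sketch: the exact comparisons in the C++ code may differ (for example in how ties or the uncertain set are counted), and the function name is made up.

```python
USE_BITSET_THRESHOLD = 1000   # value from the struct above


def choose_store_type(accepted, rejected, threshold=USE_BITSET_THRESHOLD):
    """Pick the cheapest representation for one state's context-independent tokens."""
    if len(accepted) > threshold and len(rejected) > threshold:
        return "kAcceptedBitset"      # both lists are long: a fixed-size bitset is smaller
    if len(accepted) < len(rejected):
        return "kAccepted"            # few legal tokens: list them
    return "kRejected"                # few illegal tokens: list those instead


# Inside a JSON key almost nothing is legal; inside a free-form string almost everything is.
print(choose_store_type(list(range(5)), list(range(100_000))))
print(choose_store_type(list(range(100_000)), list(range(40))))
print(choose_store_type(list(range(60_000)), list(range(40_000))))
```

Whichever list is stored, the other side is implied by the vocabulary size, so the runtime can still fill a full bitmask from it.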
The compiled grammar is literally a hash map from automaton state to this struct (cpp/compiled_grammar_impl.h:119):
std::unordered_map<ParserState, AdaptiveTokenMask, StateHashForCache> adaptive_token_mask_cache;
and the cache key deliberately ignores the input-position field, which is what makes the precompute per-automaton-state rather than per-parse-configuration:
// cpp/earley_parser.h:154
class StateHashForCache {
public:
size_t operator()(const ParserState& state) const {
return HashCombine(state.rule_id, state.sequence_id, state.element_id, state.sub_element_id);
}
};
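The effect of leaving the input position out of the key is easy to demonstrate. The Python dataclass below is a stand-in for the C++ ParserState with the field names used in this guide; the dictionary and the cache_key method are illustrative, not the library's data structures.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ParserState:
    rule_id: int
    sequence_id: int
    element_id: int
    rule_start_pos: int          # where in the input this rule began; differs per parse
    sub_element_id: int = 0

    def cache_key(self):
        """Key for the mask cache: every field except the input position."""
        return (self.rule_id, self.sequence_id, self.element_id, self.sub_element_id)


mask_cache = {}
early = ParserState(rule_id=3, sequence_id=17, element_id=2, rule_start_pos=5)
late = ParserState(rule_id=3, sequence_id=17, element_id=2, rule_start_pos=812)

mask_cache[early.cache_key()] = "mask precomputed at compile time"
print(early == late, late.cache_key() in mask_cache)
```

The two states are different parse configurations but the same automaton state, so one compile-time mask serves both, however deep into the output the rule is entered.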
The classification itself happens in GetAdaptiveTokenMask (cpp/grammar_compiler.cc:780) → GetTokenMaskWithFirstCharacterCheck. For each token (in sorted-vocab order), feed its bytes to a parser seeded with only this rule’s state and no parent context:
// cpp/grammar_compiler.cc:637 (inside the per-token loop)
bool can_reach_end = tmp_can_reach_end_prefix_or_stack_.back();
if (accepted) {
  tmp_accepted_indices_.push_back(i);
} else if (can_reach_end && prev_matched_size > 0) {
  auto [lookahead_accepted, lookahead_completed] =
      IsTokenPassLookaheadAssertion(token, tmp_can_reach_end_stack_);
  if ((!is_root_rule) && lookahead_accepted) {
    ...
    tmp_uncertain_indices_.push_back(i);
    ...
  }
} else {
  tmp_rejected_indices_.push_back(i);
}
Read that middle branch carefully — it is your “token straddles a grammar boundary” case. The token’s bytes were not fully consumed by this rule, but at some byte the rule could have completed (can_reach_end): whether the remaining bytes are legal depends on whatever rule the parent stack continues with. At compile time that’s unknowable → uncertain. Two refinements cut the uncertain set: for the root rule there is no parent, so nothing is uncertain (is_root_rule), and XGrammar-2’s lookahead assertions (the rule’s right context, when statically known) let some straddling tokens be resolved to accepted/rejected at compile time after all (IsTokenPassLookaheadAssertion, cpp/grammar_compiler.cc:312).
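The three-way split is easy to reproduce on a toy rule. The sketch below assumes a rule that matches one fixed, non-empty byte literal and ignores lookahead assertions entirely; classify_token is a hypothetical helper, not an XGrammar function.

```python
def classify_token(token: bytes, literal: bytes, is_root_rule: bool = False) -> str:
    """Classify one token against a toy rule that matches exactly `literal`."""
    for i, byte in enumerate(token):
        if i == len(literal):
            # The rule completed after i bytes and the token still has bytes left:
            # only the parent rule can say whether those bytes are legal.
            return "rejected" if is_root_rule else "uncertain"
        if byte != literal[i]:
            return "rejected"  # mismatch inside the rule: decidable right now
    return "accepted"  # every byte of the token was consumed inside the rule


vocab = [b"t", b"tr", b"true", b"true,", b"true}", b"false"]
classes = {tok: classify_token(tok, b"true") for tok in vocab}
root_classes = {tok: classify_token(tok, b"true", is_root_rule=True) for tok in vocab}
```

The tokens b"true," and b"true}" are the straddlers: they finish the rule and spill into whatever follows, so they land in the uncertain set unless the rule is the root.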
4.3 GrammarMatcher: runtime state and the two core calls
The runtime automaton state is a set of ParserStates (cpp/earley_parser.h:41 — rule_id, sequence_id, element_id within the rule FSM, rule_start_pos back-pointer, sub_element_id for mid-UTF-8 positions). The parser keeps full history (scanable_state_history_, a Compact2DArray) so any number of tokens can be rolled back — that’s what makes jump-forward and speculative decoding cheap. Advancing by one byte is classic Earley scan/predict/complete:
// cpp/earley_parser.cc:257
bool EarleyParser::Advance(const uint8_t ch, bool debug_print) {
  ...
  const auto& latest_states = scanable_state_history_[scanable_state_history_.size() - 1];
  for (const auto& state : latest_states) Scan(state, ch);
  if (tmp_process_state_queue_.empty() && tmp_states_to_be_added_.empty()) return false;
  while (!tmp_process_state_queue_.empty()) {
    const auto state = std::move(tmp_process_state_queue_.front());
    tmp_process_state_queue_.pop();
    auto [scanable, completable] = Predict(state, debug_print);
    if (completable) Complete(state, debug_print);
    if (scanable) tmp_states_to_be_added_.push_back(state);
  }
  is_completed_.push_back(tmp_accept_stop_token_);
  scanable_state_history_.PushBack(tmp_states_to_be_added_);
  return true;
}
Accepting a sampled token (GrammarMatcher::Impl::AcceptToken, cpp/grammar_matcher.cc:610) tries two paths and merges: an atomic path over kToken/kExcludeToken edges (the token matched as a unit), and the byte path (the token’s decoded bytes fed through Advance one at a time, cpp/grammar_matcher.cc:668-676). Stop tokens are only accepted when IsCompleted() (:634-641); rejection pops any partially-advanced bytes so the state is untouched.
Filling the mask (FillNextTokenBitmask, cpp/grammar_matcher.cc:809) is where compile-time and runtime meet. For every live Earley state, look up its precomputed AdaptiveTokenMask; OR together the accepted sets, AND together the rejected sets (nondeterminism = union of possibilities):
// cpp/grammar_matcher.cc:842
for (const auto& state : latest_states) {
  auto adaptive_token_mask_it = adaptive_token_mask_cache.find(state);
  XGRAMMAR_CHECK(adaptive_token_mask_it != adaptive_token_mask_cache.end()) << state;
  const auto& adaptive_token_mask = adaptive_token_mask_it->second;
  ...
  if (adaptive_token_mask.store_type == StoreType::kAcceptedBitset) {
    tmp_accepted_bitset_ |= adaptive_token_mask.accepted_bitset;
  } else if (adaptive_token_mask.store_type == StoreType::kAccepted) {
    for (auto idx : adaptive_token_mask.accepted_indices)
      tmp_accepted_bitset_.Set(sorted_decoded_vocab[idx].first, true);
  }
}
Then the uncertain tokens are resolved on the fly by actually stepping the live parser through their bytes (cpp/grammar_matcher.cc:880-940) — this is the only per-step work proportional to grammar context, and it’s bounded by the uncertain count, not the vocab size. Finally SetTokenBitmask (cpp/grammar_matcher.cc:1035) writes the bitset, turning stop tokens on only if IsCompleted() — that’s how “end of structure → only EOS legal” falls out. The function returns need_apply=false when the mask is all-ones, letting engines skip the GPU apply entirely (python/xgrammar/matcher.py:334).
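The merge logic can be sketched in a few lines of Python. Everything here is a simplification under stated assumptions: StateMask stands in for AdaptiveTokenMask with plain sets, and resolve_uncertain is a hypothetical callback that plays the role of stepping the live parser through a token's bytes.

```python
from dataclasses import dataclass, field


@dataclass
class StateMask:
    """Precomputed per-state classification over token ids (toy stand-in)."""
    accepted: set = field(default_factory=set)
    rejected: set = field(default_factory=set)
    uncertain: set = field(default_factory=set)


def fill_mask(live_masks, vocab_size, resolve_uncertain):
    """Return (bitmask, number_of_runtime_checks) for one decode step."""
    allowed = set()
    for mask in live_masks:
        allowed |= mask.accepted  # a token accepted by any live state is legal
    checks = 0
    for state_idx, mask in enumerate(live_masks):
        for token_id in sorted(mask.uncertain):
            if token_id in allowed:
                continue  # already legal via another state, no runtime work needed
            checks += 1
            if resolve_uncertain(state_idx, token_id):
                allowed.add(token_id)
    # Whatever is left was rejected by every state (intersection of rejected sets).
    return [token_id in allowed for token_id in range(vocab_size)], checks


state0 = StateMask(accepted={0, 1}, rejected={2, 3, 4, 5, 7}, uncertain={6})
state1 = StateMask(accepted={1, 2}, rejected={3, 4, 5, 6}, uncertain={0, 7})
bitmask, checks = fill_mask([state0, state1], 8, lambda state_idx, tok: tok == 7)
```

Only two runtime checks happen for an 8-token vocabulary, which is the point: per-step work tracks the uncertain count, not the vocabulary size.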
4.4 Tokens vs grammar alignment: bytes, sorted vocab, trie ranges
The grammar operates on bytes, and the vocab is normalized to bytes up front. TokenizerInfo decodes each tokenizer entry through ByteFallbackDecoder (<0x1B> → raw byte 0x1B) or ByteLevelDecoder (inverse of GPT-2 bytes-to-unicode) at cpp/tokenizer_info.cc:47 and :83, so a token is just a byte string and partial UTF-8 inside a token is handled by sub_element_id/partial_codepoint in ParserState.
Matching 100k token strings per state is made tractable by sorting the decoded vocab lexicographically and precomputing, for each token, the half-open range of tokens that have it as a prefix — a flattened prefix trie:
// cpp/tokenizer_info.cc:306
// The value means: the subtree is [i, trie_subtree_nodes_range[i]).
trie_subtree_nodes_range_.resize(sorted_decoded_vocab_.size(), 0);
std::stack<std::pair<std::string, int32_t>> prefix_stack;
for (size_t i = 0; i < sorted_decoded_vocab_.size(); ++i) {
  const auto& token = sorted_decoded_vocab_[i].second;
  while ((!prefix_stack.empty()) && (token.find(prefix_stack.top().first) == std::string::npos)) {
    const auto& top_pair = prefix_stack.top();  // the entry whose subtree ends at i
    trie_subtree_nodes_range_[top_pair.second] = i;
    prefix_stack.pop();
  }
  prefix_stack.push({token, i});
}
Both the compile-time classifier and the runtime uncertain-resolver exploit this: consecutive sorted tokens share prefixes, so the parser only advances/pops the differing suffix (the std::mismatch LCP logic at cpp/grammar_compiler.cc:596 and cpp/grammar_matcher.cc:901), and when a prefix is rejected the entire subtree range is skipped (subtree_range[cur_token_idx], cpp/grammar_matcher.cc:888). Note the masks store sorted-vocab indices, translated to token IDs only when writing the final bitmask.
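A minimal Python sketch of the flattened trie follows. It uses a strict prefix test (startswith) and also closes the ranges still on the stack at the end, which the excerpt above does not show; both functions are illustrative helpers, not XGrammar's API.

```python
def build_subtree_ranges(sorted_vocab):
    """For each i, sorted_vocab[i:ranges[i]] is exactly the set of tokens starting with sorted_vocab[i]."""
    ranges = [0] * len(sorted_vocab)
    stack = []  # indices of tokens that are prefixes of the current token
    for i, token in enumerate(sorted_vocab):
        while stack and not token.startswith(sorted_vocab[stack[-1]]):
            ranges[stack.pop()] = i  # sorted order: no later token can extend it
        stack.append(i)
    while stack:
        ranges[stack.pop()] = len(sorted_vocab)
    return ranges


def tokens_after_pruning(sorted_vocab, ranges, is_rejected_prefix):
    """Visit tokens in order, jumping over the whole subtree of any rejected token."""
    visited, i = [], 0
    while i < len(sorted_vocab):
        if is_rejected_prefix(sorted_vocab[i]):
            i = ranges[i]
            continue
        visited.append(sorted_vocab[i])
        i += 1
    return visited


vocab = sorted([b"a", b"ab", b"abc", b"abd", b"b", b"ba", b"c"])
ranges = build_subtree_ranges(vocab)
visited = tokens_after_pruning(vocab, ranges, lambda tok: tok == b"ab")
```

Rejecting b"ab" skips b"abc" and b"abd" without ever feeding their bytes to a parser, which is the saving both the compile-time and runtime loops rely on.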
4.5 Mask application on GPU
The bitmask layout is ceil(vocab_size/32) int32 words per row (GetBitmaskSize → DynamicBitset::GetBufferSize, cpp/grammar_matcher.cc:36; shape helper at python/xgrammar/matcher.py:19). Bit=1 means allowed; bit=0 means set the logit to −inf. For a 128k vocab that’s 16 KB per request — small enough that engines copy it CPU→GPU every step. The default GPU path is Triton:
# python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py:64
batch_id = row_id if indices_ptr is None else tl.load(indices_ptr + row_id)
offsets = block_offset + tl.arange(0, BLOCK_SIZE)
bitmask_offsets = block_offset // 32 + tl.arange(0, BLOCK_SIZE // 32)
vocab_mask = offsets < vocab_size
packed_bitmask_mask = bitmask_offsets < bitmask_strides
packed_bitmask = tl.load(
bitmask_ptr + batch_id * bitmask_strides + bitmask_offsets, packed_bitmask_mask
)
bitmask = ((packed_bitmask[:, None] >> (tl.arange(0, 32)[None, :])) & 1) == 0
bitmask = bitmask.reshape(BLOCK_SIZE)
tl.store(logits_ptr + batch_id * logits_strides + offsets, -float("inf"), vocab_mask & bitmask)
indices lets a mixed batch apply masks only to the structured requests. The CUDA twin lives in python/xgrammar/kernels/apply_token_bitmask_inplace_cuda.cu (BITS_PER_BLOCK = 32 at :34, vectorized LogitsBitmaskKernel at :64, JIT-compiled through torch.utils.cpp_extension). Backend dispatch — cpu / triton / cuda / torch_compile / torch_native / mlx — is in python/xgrammar/matcher.py:58-188. There is also a pure-C++ CPU apply (ApplyTokenBitmaskInplaceCPU, include/xgrammar/matcher.h:32).
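The bit layout itself is simple enough to check in pure Python. This sketch assumes the layout described above (bit i of word i // 32, 1 = allowed) and uses unbounded Python ints where the real buffer holds int32 words; pack_allowed and apply_bitmask are illustrative names, not xgrammar functions.

```python
import math

BITS_PER_WORD = 32


def bitmask_words(vocab_size: int) -> int:
    """Number of 32-bit words needed for one row."""
    return math.ceil(vocab_size / BITS_PER_WORD)


def pack_allowed(allowed_ids, vocab_size):
    """Build one bitmask row with a 1 bit for every allowed token id."""
    words = [0] * bitmask_words(vocab_size)
    for token_id in allowed_ids:
        words[token_id // BITS_PER_WORD] |= 1 << (token_id % BITS_PER_WORD)
    return words


def apply_bitmask(logits, words):
    """Return a copy of logits with disallowed entries set to -inf."""
    return [
        x if (words[i // BITS_PER_WORD] >> (i % BITS_PER_WORD)) & 1 else -math.inf
        for i, x in enumerate(logits)
    ]


row_bytes = bitmask_words(128_000) * 4  # 4 bytes per int32 word
masked = apply_bitmask([0.5] * 40, pack_allowed({0, 33, 39}, 40))
```

For a 128,000-entry vocabulary the row comes to 16,000 bytes, which is the "16 KB per request" figure quoted above.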
4.6 Jump-forward / forced tokens
When the grammar forces output (e.g. {"name": " after {), the engine shouldn’t pay a forward pass per forced character. FindJumpForwardString walks the live states’ FSM edges and extends the result while every state agrees on exactly one possible next byte:
// cpp/grammar_matcher.cc:984 (inside the while loop)
int next_char = -1;
for (const auto& state : states) {
  const auto& fsm = grammar_->per_rule_fsms[state.rule_id].value();
  const auto& current_edges = fsm.GetFsm().GetFsm().GetEdges(state.element_id);
  for (const auto& edge : current_edges) {
    if (!edge.IsCharRange()) continue;
    if (edge.min != edge.max) { can_find_next_char = false; break; }
    if (next_char == -1) next_char = edge.min;
    else if (next_char != edge.min) { can_find_next_char = false; break; }
  }
}
// 2. If found, accept the char and iterate; finally PopLastStates(num_accepted_chars)
It advances the parser to find the next forced byte, then rolls everything back (PopLastStates) — state is unchanged. The engine retokenizes the returned string and force-appends; that retokenization can shift token boundaries, which is why Rollback (cpp/grammar_matcher.cc:1023) exists. SGLang uses this (try_jump_forward); vLLM v1 currently only references it in a comment (backend_xgrammar.py:137). The related single-token shortcut _IsSingleTokenBitmask (include/xgrammar/matcher.h:30) lets an engine skip sampling when the mask allows exactly one token.
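The "extend while every state agrees" loop can be sketched on a toy model where each live state is just a position inside a byte literal. This is an assumption made for illustration; the real function walks FSM edges, and find_jump_forward is a hypothetical name.

```python
def find_jump_forward(states):
    """states: list of (literal: bytes, pos: int) pairs, one per live parser state.

    Returns the longest byte string that every state forces next. The input is
    never mutated, mirroring the advance-then-roll-back behaviour described above.
    """
    forced = bytearray()
    offset = 0
    while states:
        next_bytes = set()
        for literal, pos in states:
            if pos + offset >= len(literal):
                return bytes(forced)  # a state can end here: nothing is forced
            next_bytes.add(literal[pos + offset])
        if len(next_bytes) != 1:
            return bytes(forced)  # states disagree: the model has a real choice
        forced.append(next_bytes.pop())
        offset += 1
    return bytes(forced)


single = find_jump_forward([(b'{"name": "', 1)])
ambiguous = find_jump_forward([(b'{"name": "', 1), (b'{"nickname": "', 1)])
```

With one live state the whole key prefix is forced; with two candidate keys the forced run stops at the first byte where they differ.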
4.7 Speculative decoding hook
TraverseDraftTree (include/xgrammar/matcher.h:130, impl cpp/grammar_matcher.cc:55-110) DFS-walks a draft token tree, calling AcceptToken/FillNextTokenBitmask/Rollback at each node to produce one mask per draft position — the grammar-side counterpart of tree-based spec decode. Unlimited rollback is what makes this and jump-forward safe (python/xgrammar/matcher.py:230: “the new Earley parser … allows unlimited rollback”).
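The accept / fill / rollback pattern over a draft tree looks like this in miniature. ToyMatcher is a hypothetical stand-in that accepts a fixed set of token sequences; it only illustrates the traversal shape, not XGrammar's matcher.

```python
class ToyMatcher:
    """Accepts any prefix of the token sequences in `allowed_sequences`, with rollback."""

    def __init__(self, allowed_sequences):
        self.allowed = [tuple(s) for s in allowed_sequences]
        self.history = []

    def next_mask(self):
        """Set of token ids that are legal after the current history."""
        n, prefix = len(self.history), tuple(self.history)
        return {s[n] for s in self.allowed if len(s) > n and s[:n] == prefix}

    def accept_token(self, token):
        if token not in self.next_mask():
            return False
        self.history.append(token)
        return True

    def rollback(self, num_tokens):
        del self.history[len(self.history) - num_tokens:]


def traverse_draft_tree(matcher, children, node, masks):
    """children maps node -> [(token, child_node)]; fills masks[node] for reachable nodes."""
    masks[node] = matcher.next_mask()
    for token, child in children.get(node, []):
        if matcher.accept_token(token):  # grammar-invalid drafts are pruned here
            traverse_draft_tree(matcher, children, child, masks)
            matcher.rollback(1)  # restore the state before trying the next sibling


matcher = ToyMatcher([(1, 2, 3), (1, 4)])
draft_tree = {0: [(1, 1)], 1: [(2, 2), (4, 3), (9, 4)]}
masks = {}
traverse_draft_tree(matcher, draft_tree, 0, masks)
```

Node 4 never gets a mask because its draft token (9) is rejected, and the matcher ends the walk with an empty history.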
4.8 Engine integration surface
The contract an engine implements per decode step (see docs/tutorials/engine_integration.md):
# per model: tokenizer_info = xgr.TokenizerInfo.from_huggingface(tok, vocab_size=config.vocab_size)
# compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8) # holds the LRU cache
# per request: matcher = xgr.GrammarMatcher(compiler.compile_json_schema(schema))
# per step:
matcher.fill_next_token_bitmask(token_bitmask, i) # CPU, one row per request
xgr.apply_token_bitmask_inplace(logits, token_bitmask.to(logits.device)) # GPU
# after sampling:
assert matcher.accept_token(next_token_id)
BatchGrammarMatcher (cpp/grammar_matcher.cc:1117) parallelizes mask fills across requests with an internal thread pool — fill time is per-request CPU work on the critical path between forward pass and sampling, so engines overlap or thread it.
In your local checkouts:
- vLLM (materials/vllm/vllm): backend at vllm/v1/structured_output/backend_xgrammar.py (fill_bitmask → matcher.fill_next_token_bitmask at :191, allocate_token_bitmask at :124). Orchestration in vllm/v1/structured_output/__init__.py: a StructuredOutputManager that runs fills on a ThreadPoolExecutor once the structured batch exceeds 128 requests (:62-64). The bitmask travels scheduler→worker, and the GPU apply is vLLM’s own Triton kernel, _apply_grammar_bitmask_kernel in vllm/v1/worker/gpu/structured_outputs.py:86 (same bit layout, reimplemented engine-side).
- SGLang (materials/sglang/sglang): backend at python/sglang/srt/constrained/xgrammar_backend.py, with accept_token (:79), fill_vocab_mask (:105), apply_vocab_mask dispatching to xgrammar’s kernels (:112), and jump-forward wired up via try_jump_forward (:144). Selected among backends in base_grammar_backend.py / grammar_manager.py in the same directory.
5. Suggested reading path
1. docs/tutorials/engine_integration.md: the whole loop in 60 lines of Python; confirms your model end-to-end.
2. include/xgrammar/matcher.h + include/xgrammar/compiler.h: the entire public API surface, ~370 lines.
3. cpp/compiled_grammar_impl.h: AdaptiveTokenMask and the state→mask map; the central data structure.
4. cpp/earley_parser.h then cpp/earley_parser.cc:257 (Advance): the runtime automaton; skim Scan/Predict/Complete.
5. cpp/grammar_matcher.cc: AcceptToken (:610), FillNextTokenBitmask (:809), FindJumpForwardString (:964), SetTokenBitmask (:1035).
6. cpp/grammar_compiler.cc: MultiThreadCompileGrammar (:1048), GetAdaptiveTokenMask (:780), the classification loop (:594-665), and the two cache layers (:1394-1400).
7. cpp/tokenizer_info.cc:266-323: vocab decode, sort, trie ranges; plus the byte decoders at :27-100.
8. python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py: 130 lines; then compare vLLM’s copy in vllm/v1/worker/gpu/structured_outputs.py.
Optional ninth: cpp/json_schema_converter.cc (only if you want schema→EBNF details) and tests/python/test_grammar_matcher_basic.py for executable documentation.
6. Connections to your other study repos
- vllm: consumer. Integration tier: vllm/v1/structured_output/ (manager + xgrammar/guidance/outlines backends behind backend_types.py), GPU apply in vllm/v1/worker/gpu/structured_outputs.py. Note the division of labor: xgrammar fills masks in the scheduler process; the worker applies them, so the bitmask is part of scheduler→worker IPC.
- sglang: consumer. python/sglang/srt/constrained/ mirrors the same backend abstraction and additionally exercises jump-forward. SGLang historically drove the jump-forward idea (its Outlines backend has outlines_jump_forward.py).
- nano-vllm: no structured output at all. Adding it is a clean exercise: (1) hold a GrammarMatcher per sequence in nanovllm/engine/scheduler.py’s sequence state, (2) in nanovllm/engine/model_runner.py, fill a bitmask for grammar-bearing sequences and apply_token_bitmask_inplace on the logits before nanovllm/layers/sampler.py runs, (3) call accept_token on the sampled IDs, reset/free on finish. ~100 lines; the subtleties are batching (indices for mixed batches) and not blocking the step on compile (vLLM compiles asynchronously and treats the grammar as “not ready” until done).
- flashinfer: sibling, not a dependency. Same architectural slot as xgrammar’s kernels: small device ops bracketing sampling. In a vLLM step the order is xgrammar bitmask apply → flashinfer/torch sampling kernels. No xgrammar code in the flashinfer repo.
- dynamo: mostly unrelated (routing/disaggregation tier). Its only touchpoint: passing guided_decoding_backend: xgrammar|llguidance through to TensorRT-LLM (components/src/dynamo/trtllm/backend_args.py:231). One traffic-relevant nuance: structured requests change per-step cost (CPU mask fill on the critical path, compile spikes on first sight of a schema), which a router could in principle account for; none of these do today.
- llm-d, gateway-api-inference-extension: honestly unrelated. Structured output is opaque request payload by the time it reaches the gateway/EPP tier; nothing to find there.
7. Hands-on on your machine (RTX 5080, 16 GB, Windows)
pip install xgrammar ships prebuilt wheels including Windows. Caveat: the Triton dependency is Linux-x86_64-only (pyproject.toml:22), and the cuda backend JIT-compiles via nvcc+MSVC — on native Windows, GPU mask apply falls back to torch_compile. That doesn’t matter: everything interesting here is CPU-only. The GPU never sees the grammar, only 16 KB of bits. (For the Triton path, use WSL2.)
Experiment 1: watch the mask evolve (no GPU)
import time, xgrammar as xgr
from transformers import AutoTokenizer
from xgrammar.testing import _get_masked_tokens_from_bitmask
tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
info = xgr.TokenizerInfo.from_huggingface(tok)
compiler = xgr.GrammarCompiler(info, max_threads=8)
schema = '{"type":"object","properties":{"name":{"type":"string"},"port":{"type":"integer"}},"required":["name","port"]}'
t0 = time.perf_counter()
cg = compiler.compile_json_schema(schema)
print(f"compile {(time.perf_counter()-t0)*1e3:.1f} ms, mask cache {cg.memory_size_bytes/2**20:.2f} MiB")
matcher, bitmask = xgr.GrammarMatcher(cg), xgr.allocate_token_bitmask(1, info.vocab_size)
for tid in tok.encode('{"name": "edge-proxy", "port": 8443}', add_special_tokens=False):
    t0 = time.perf_counter()
    matcher.fill_next_token_bitmask(bitmask)
    us = (time.perf_counter() - t0) * 1e6
    allowed = info.vocab_size - len(_get_masked_tokens_from_bitmask(bitmask, info.vocab_size))
    print(f"fill {us:7.1f} us allowed {allowed:6d} jump={matcher.find_jump_forward_string()!r} next={tok.decode([tid])!r}")
    assert matcher.accept_token(tid)
Call compile_json_schema a second time on the same compiler: it returns in microseconds (LRU cache hit). The cache lives in the GrammarCompiler object, so rerunning the whole script in a fresh process compiles cold again. The jump= column shows forced text ({"name": " etc.) shrinking to '' whenever the model has a real choice. Add debug_print=True to fill_next_token_bitmask/accept_token, or call matcher._debug_print_internal_state(), to see the live Earley states.
Experiment 2: mask-fill latency on a gnarly schema
Fill cost scales with live states × uncertain tokens, so stress it: deep nesting, anyOf fans, and regex-pattern strings (format: email/uuid style patterns produce character-class-heavy automata where few tokens are decidable per state). Compare against the trivial schema above; also try compiler.compile_builtin_json_grammar() (free-form JSON — highly recursive) and a pathological EBNF of your own via compiler.compile_grammar(ebnf). Time three things separately: first compile (cold), fill_next_token_bitmask per step along a long valid output (use matcher.accept_string(...) to fast-forward), and BatchGrammarMatcher.batch_fill_next_token_bitmask with 64 forked matchers (matcher.fork()) to see the threaded path engines actually use. The repo’s own harness for exactly this comparison is examples/benchmark/bench_grammar_compile_mask_gen.py (xgrammar vs outlines vs llguidance) and examples/benchmark/bench_apply_token_bitmask_inplace.py for the kernel side.
Experiment 3 (GPU, optional): end-to-end with a real model
The HF integration in python/xgrammar/contrib/hf.py provides a LogitsProcessor; with a 0.5–3B model on the 5080 you can run docs/tutorials/workflow_of_xgrammar.md’s generation loop verbatim and verify the per-step overhead is dominated by the CPU fill, not the mask apply. On native Windows pass backend="torch_compile" to apply_token_bitmask_inplace (or run under WSL2 for Triton).
NVIDIA Dynamo
1. What it is
Dynamo is NVIDIA’s open-source, datacenter-scale inference serving framework. It does not run models itself — it is the orchestration layer above inference engines (vLLM, SGLang, TensorRT-LLM), turning a fleet of single-node engines into one coordinated serving system:
- OpenAI-compatible frontend (Rust, axum) with preprocessing (chat templates, tokenization)
- KV-aware router that scores workers by cached-prefix overlap plus live load
- Disaggregated prefill/decode with GPU-to-GPU KV-cache transfer via NIXL
- KV Block Manager (KVBM) for multi-tier KV offload (GPU → pinned host → disk → S3)
- Planner for SLA-driven autoscaling of prefill/decode pools
- A distributed runtime (service discovery, leases, request plane, event plane) that all of the above is built on
The mental model: engines optimize a box; Dynamo optimizes the fleet. If you serve one model on one GPU, you don’t need it. If you serve across nodes and want to never recompute a prefix that already exists somewhere in the cluster, this is the layer that does it.
Version in this snapshot: workspace 1.3.0 (Cargo.toml:36). Core is Rust (edition 2024,
tokio); engine adapters and the planner are Python; everything meets at PyO3 bindings.
2. Why you care
This repo is the closest thing in the ecosystem to “your current job, but for LLM inference”:
- It is traffic infrastructure in Rust. The hot path (HTTP ingress, routing decision, request dispatch, token streaming) is all tokio/axum/tonic code. The same vocabulary you use at the edge (connection pools, pub/sub planes, leases, backpressure, softmax-with-temperature load balancing, sticky sessions, request migration) appears here with one new twist: the load balancer has to reason about where bytes of KV cache already live.
- Routing is the product. A request router whose cost function is prefill_blocks_you'd_recompute + decode_blocks_in_flight is cache-aware load balancing: the same family of problem as consistent-hash + least-loaded at Discord, but the cache is GPU VRAM and a miss costs a multi-second prefill.
- There is literally an Envoy ext_proc server in here. deploy/inference-gateway/ext-proc is a Rust gRPC ext_proc implementation of the Gateway API Inference Extension’s Endpoint Picker, backed by the same dynamo-kv-router crate. Your Envoy background maps directly.
- Token egress is a designed transport, not an afterthought: a call-home TCP response plane, separate from the request plane, separate from the event plane.
3. Architecture map
The Rust workspace
Workspace members are declared in Cargo.toml:4-32. The crates that matter, roughly by layer:
| Crate | Path | What it does |
|---|---|---|
| dynamo-runtime | lib/runtime | The distributed runtime: DistributedRuntime, Namespace/Component/Endpoint model, discovery backends (K8s, etcd, file, memory), etcd leases, NATS/TCP/ZMQ transports, the pipeline graph (Source → Operator → Sink), push router, metrics |
| dynamo-llm | lib/llm | Everything LLM-specific: OpenAI/Anthropic HTTP service, preprocessor (chat template + tokenize), KvPushRouter, PrefillRouter (disagg orchestration), migration, model discovery/cards, KVBM block manager v2, gRPC (KServe) frontend, engines glue |
| dynamo-kv-router | lib/kv-router | The routing core, runtime-free: radix-tree indexers (single-thread, concurrent, compressed), scheduling (cost function, queue, admission), active-sequence tracking, replica sync, ZMQ wire format |
| dynamo-tokens | lib/tokens | Token sequences, block hashing helpers |
| dynamo-kv-hashing | lib/kv-hashing | “Request → PositionalLineageHash” contract so router, KVBM, and engines agree on block identity |
| dynamo-protocols | lib/protocols | OpenAI-compatible API types with NVIDIA extensions (nvext) |
| dynamo-tokenizers / dynamo-renderer | lib/tokenizers, lib/renderer | HF tokenizers and minijinja chat-template rendering, runtime-free |
| dynamo-parsers | lib/parsers | Tool-calling and reasoning-trace parsers |
| dynamo-memory + kvbm-* | lib/memory, lib/kvbm-* | Memory arenas, NIXL agent bindings (lib/memory/src/nixl/agent.rs), and the KV Block Manager split into logical/physical/engine/kernels crates |
| dynamo-mocker | lib/mocker | GPU-free simulated engine: vLLM/SGLang-style schedulers, KV events, disagg; drives the real router/frontend without hardware |
| dynamo-backend-common | lib/backend-common | Shared glue for writing pure-Rust backends |
| dynamo-ext-proc | deploy/inference-gateway/ext-proc | Envoy ext_proc gRPC server (GAIE Endpoint Picker) reusing dynamo-kv-router |
Python bindings and components
- lib/bindings/python is a PyO3 crate exposing the runtime and LLM stack as dynamo._core (lib/bindings/python/rust/lib.rs). Python gets DistributedRuntime, Endpoint, KvEventPublisher, register_model, make_engine, the planner hooks, etc.
- components/src/dynamo/ is the Python product surface: frontend (thin wrapper that launches the Rust HTTP server + router), vllm / sglang / trtllm (engine workers), planner, mocker, router (standalone router process), global_router, profiler.
How engines plug in as workers
An engine worker is just a Python process that (1) builds a DistributedRuntime, (2) creates
an endpoint like dyn://namespace.component.generate, (3) registers a model card, and
(4) serves an async generator. From the vLLM component:
# components/src/dynamo/vllm/worker_factory.py:196-230 (abridged)
generate_endpoint = runtime.endpoint(
    f"{config.namespace}.{config.component}.{config.endpoint}"
)
...
await register_model(
    ModelInput.Tokens, ModelType.Empty, generate_endpoint, config.model,
    model_name=config.served_model_name or config.model,
    worker_type=WorkerType.Encode,
    needs=[[WorkerType.Prefill, WorkerType.Decode], [WorkerType.Aggregated]],
)
...
await generate_endpoint.serve_endpoint(
    handler.generate, metrics_labels=[("model", config.model)]
)
register_model publishes a Model Deployment Card into discovery; the frontend’s
ModelWatcher (lib/llm/src/discovery/watcher.rs) sees it and wires a new pipeline. The
engine never speaks HTTP — it receives already-tokenized PreprocessedRequests and yields
token deltas. The handler’s generate(request, context) is the entire engine contract.
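To show the shape of that contract without any Dynamo imports, here is a minimal sketch of an async-generator handler. The request and context dictionaries are hypothetical stand-ins (only the token_ids and stop_conditions.max_tokens field names are borrowed from excerpts in this guide), and the "engine" just emits fake token ids.

```python
import asyncio


async def generate(request, context):
    """Hypothetical handler: yields one token delta per step for a tokenized request."""
    prompt_len = len(request["token_ids"])
    max_tokens = request["stop_conditions"]["max_tokens"]
    for step in range(max_tokens):
        if context.get("cancelled"):  # stand-in for a real cancellation signal
            break
        await asyncio.sleep(0)  # stand-in for waiting on the engine
        yield {"token_ids": [prompt_len + step]}  # fake token id, one delta per step


async def collect(request, context):
    """Drain the stream the way a caller on the other side of the endpoint would."""
    return [chunk async for chunk in generate(request, context)]


chunks = asyncio.run(collect(
    {"token_ids": [101, 102, 103], "stop_conditions": {"max_tokens": 4}},
    {"cancelled": False},
))
```

The worker never sees text or HTTP: tokens in, token deltas out, with cancellation checked between steps.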
4. Core mechanisms
4.1 Request flow end to end
Standalone-mode flow: client → Frontend (HTTP) → [preprocess → migration → detokenize → PrefillRouter → KvPushRouter] → worker, with tokens streaming back through the same chain.
The whole thing is one typed pipeline built from operators with forward and backward edges:
// lib/llm/src/entrypoint/input/common.rs:352-362
let engine = frontend
    .link(preprocessor_op.forward_edge())?
    .link(migration.forward_edge())?
    .link(token_backend.forward_edge())?
    .link(prefill_op.forward_edge())?
    .link(backend)?
    .link(prefill_op.backward_edge())?
    .link(token_backend.backward_edge())?
    .link(migration.backward_edge())?
    .link(preprocessor_op.backward_edge())?
    .link(frontend)?;
Forward edges transform the request (OpenAI JSON → templated prompt → token IDs); backward
edges transform the response stream (token IDs → detokenized deltas → OpenAI SSE chunks).
Migration sits high enough to replay/redirect an in-flight request if a worker dies.
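The forward/backward symmetry is easier to see in a toy. The sketch below assumes each operator is just a pair of functions and that responses are single values rather than streams; Operator and run_pipeline are illustrative names, not Dynamo types.

```python
class Operator:
    """One pipeline stage: `forward` rewrites the request, `backward` the response."""

    def __init__(self, name, forward, backward):
        self.name, self.forward, self.backward = name, forward, backward


def run_pipeline(operators, backend, request):
    for op in operators:            # forward edges, outermost operator first
        request = op.forward(request)
    response = backend(request)
    for op in reversed(operators):  # backward edges, innermost operator first
        response = op.backward(response)
    return response


preprocessor = Operator(
    "preprocessor",
    forward=lambda text: [ord(c) for c in text],  # "tokenize"
    backward=lambda text: {"choices": [{"delta": {"content": text}}]},
)
detokenizer = Operator(
    "token_backend",
    forward=lambda ids: ids,
    backward=lambda ids: "".join(chr(i) for i in ids),  # "detokenize"
)
result = run_pipeline([preprocessor, detokenizer], lambda ids: ids[::-1], "abc")
```

Each operator sees the request on the way in and the matching response on the way out, which is what lets an operator like migration hold enough state to replay a request.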
Transport choices — three independent planes, which is the part worth studying closely:
- Client-facing egress is axum SSE. The engine’s response stream is mapped through metrics observers and a disconnect monitor, then wrapped:
// lib/llm/src/http/service/openai.rs:760-768
let stream = monitor_for_disconnects(stream, ctx, inflight_guard, stream_handle);
let mut sse_stream = Sse::new(stream);
if let Some(keep_alive) = state.sse_keep_alive() {
    sse_stream = sse_stream.keep_alive(KeepAlive::default().interval(keep_alive));
}
Ok(sse_stream.into_response())
- Request plane (frontend → worker RPC) defaults to direct TCP, selectable to NATS via DYN_REQUEST_PLANE (docs/design-docs/request-plane.md). Routing mode aside, the actual dispatch is PushRouter::direct(request, instance_id) (lib/runtime/src/pipeline/network/egress/push_router.rs).
- Response plane is the part that will feel familiar from designing token-egress paths: the requester runs a TcpStreamServer (lib/runtime/src/pipeline/network/tcp/server.rs:87-93), and the request carries connection info; the worker calls home over a fresh TCP stream (CallHomeHandshake, same file) and frames response chunks back with a two-part codec. So request fan-out and token fan-in are decoupled: tokens never transit NATS even when the request did. Frame inactivity timeouts, prologue handshakes, and per-stream contexts for cancellation are all explicit (lib/runtime/src/pipeline/network/egress/push_router.rs:52-62).
There is also a separate event plane (NATS Core or ZMQ, DYN_EVENT_PLANE) carrying KV
events and router replica sync — covered next. A KServe gRPC frontend exists alongside HTTP
(lib/llm/src/grpc), and lib/llm/src/http/service/realtime.rs covers websocket realtime.
4.2 KV-aware routing
The router answers: given these tokens, which worker would do the least new work? Three pieces: block hashing, a global prefix index fed by worker events, and a cost function that fuses cache overlap with live load.
Block hashing. Token sequences are chunked into engine-sized blocks (16/64 tokens) and hashed with XXH3; LoRA adapter identity is mixed into the seed so adapters don’t cross-hit:
// lib/kv-router/src/protocols.rs:85-97
pub fn compute_block_hash_for_seq(
    tokens: &[u32],
    kv_block_size: u32,
    options: BlockHashOptions<'_>,
) -> Vec<LocalBlockHash> {
    if kv_block_size == 0 {
        return Vec::new();
    }
    let seed = match options.lora_name.filter(|n| !n.is_empty()) {
        Some(name) => XXH3_SEED.wrapping_add(xxh3::xxh3_64(name.as_bytes())),
        None => XXH3_SEED,
    };
    ...
}
These are local hashes computed from token content, so the router never depends on engines
agreeing on block IDs (docs/design-docs/router-design.md, “Deterministic Event IDs”).
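A rough Python equivalent shows why content hashing makes block identity portable. Several things here are assumptions: keyed BLAKE2b from the standard library stands in for XXH3 (so the values will not match Dynamo's), BASE_SEED is made up, and dropping the partial trailing block is a guess at behaviour the excerpt does not show.

```python
import hashlib
import struct

BASE_SEED = 1337  # hypothetical; the real XXH3_SEED value is not shown above
MASK64 = 0xFFFFFFFFFFFFFFFF


def _hash64(data: bytes, seed: int) -> int:
    """64-bit seeded hash; BLAKE2b is a stand-in for XXH3, not the real algorithm."""
    key = struct.pack("<Q", seed & MASK64)
    return int.from_bytes(hashlib.blake2b(data, digest_size=8, key=key).digest(), "little")


def block_hashes(tokens, kv_block_size, lora_name=None):
    """One hash per full block of token ids, seeded by the LoRA adapter name if any."""
    if kv_block_size == 0:
        return []
    seed = BASE_SEED
    if lora_name:
        seed = (BASE_SEED + _hash64(lora_name.encode(), 0)) & MASK64  # wrapping add
    full = len(tokens) - len(tokens) % kv_block_size  # assumption: drop the partial tail
    return [
        _hash64(struct.pack(f"<{kv_block_size}I", *tokens[start:start + kv_block_size]), seed)
        for start in range(0, full, kv_block_size)
    ]


base = block_hashes(list(range(40)), 16)
same_prefix = block_hashes(list(range(32)) + [999] * 8, 16)
with_lora = block_hashes(list(range(40)), 16, lora_name="adapter-a")
```

Two requests that share their first 32 tokens produce identical hashes for both full blocks on any machine, while the same tokens under a LoRA adapter hash differently and therefore never cross-hit.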
KV events → indexer. Engines publish stored/removed/cleared events as blocks enter and
leave their paged KV caches (vLLM via its ZMQ KvEventPublisher, the mocker natively):
// lib/kv-router/src/protocols.rs:640-646
pub enum KvCacheEventData {
    Stored(KvCacheStoreData),
    Removed(KvCacheRemoveData),
    Cleared,
}
The KvIndexer folds these into a radix tree over block hashes where every node records
which workers hold that prefix. A lookup walks the query’s hash sequence and returns
per-worker matched depth:
// lib/kv-router/src/indexer/radix_tree.rs:200-206
pub fn find_matches(&self, sequence: Vec<LocalBlockHash>, early_exit: bool) -> OverlapScores {
    self.find_match_details(sequence, early_exit).overlap_scores
}
pub fn apply_event(&mut self, event: RouterEvent) -> Result<(), KvCacheEventError> {
    self.apply_event_with_counters(event, None)
}
OverlapScores is just FxHashMap<WorkerWithDpRank, u32> (lib/kv-router/src/protocols.rs:942-947).
Implementations range from a single-threaded compressed tree to a sharded
ConcurrentRadixTree with sticky per-worker write threads (lib/kv-router/src/indexer/).
Event-loss handling is sequence-number based: each worker numbers its events monotonically,
the router detects gaps and re-queries the worker’s local indexer for the missing range,
and dumps full worker state on discovery — or you can opt into NATS JetStream durable
consumption with snapshots in object store (--durable-kv-events).
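The index and the gap check can both be sketched in a few lines. This is a toy under stated assumptions: nested dictionaries instead of a real radix tree, a "removed" event that names a full prefix path, and a hypothetical detect_gap helper for the sequence-number logic.

```python
class PrefixIndex:
    """Toy prefix index: nested dicts keyed by block hash, each node holding worker ids."""

    def __init__(self):
        self.root = {"children": {}, "workers": set()}

    def apply_stored(self, worker, block_hashes):
        node = self.root
        for h in block_hashes:
            node = node["children"].setdefault(h, {"children": {}, "workers": set()})
            node["workers"].add(worker)

    def apply_removed(self, worker, block_hashes):
        """Drop `worker` from the node at the end of this prefix path."""
        node = self.root
        for h in block_hashes:
            node = node["children"].get(h)
            if node is None:
                return
        node["workers"].discard(worker)

    def find_matches(self, block_hashes):
        """Per-worker count of leading blocks that the worker holds contiguously."""
        scores, node, alive = {}, self.root, None
        for depth, h in enumerate(block_hashes, start=1):
            node = node["children"].get(h)
            if node is None:
                break
            alive = set(node["workers"]) if alive is None else alive & node["workers"]
            if not alive:
                break
            for worker in alive:
                scores[worker] = depth
        return scores


def detect_gap(last_seen, event_id):
    """Missing inclusive (start, end) event-id range, or None when events are in order."""
    return (last_seen + 1, event_id - 1) if event_id > last_seen + 1 else None


index = PrefixIndex()
index.apply_stored("w0", [11, 22, 33])
index.apply_stored("w1", [11, 22])
before = index.find_matches([11, 22, 33, 44])
index.apply_removed("w0", [11, 22, 33])
after = index.find_matches([11, 22, 33, 44])
```

The returned dictionary is the toy analogue of OverlapScores: worker id mapped to matched depth in blocks.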
Scoring. Selection happens in DefaultWorkerSelector — the cost (“logit”, lower is
better) per worker is computed in worker_logit:
// lib/kv-router/src/scheduling/selector.rs:194-203
let effective_overlap_score_credit = weights.overlap_score_credit * overlap_credit_decay;
let overlap_credit_blocks = effective_overlap_score_credit * device_overlap_blocks
    + self.kv_router_config.host_cache_hit_weight * host_overlap_blocks
    + self.kv_router_config.disk_cache_hit_weight * disk_overlap_blocks
    + shared_overlap_blocks;
let adjusted_prefill_blocks = (raw_prefill_blocks - overlap_credit_blocks).max(0.0);
let prefill_cost_blocks = weights.prefill_load_scale * adjusted_prefill_blocks;
let worker_load = worker_load.unwrap_or_default();
let decode_cost_blocks = worker_load.potential_decode_blocks() as f64;
let logit = prefill_cost_blocks + decode_cost_blocks;
Read it as: blocks I’d have to prefill from scratch (discounted by what the worker already
holds on device, in pinned host memory, on disk, or in a shared external cache — each tier
with its own credit weight) plus blocks already decoding there. An overlap_credit_decay
term softly forfeits cache affinity on workers whose prefill backlog exceeds the fleet floor
— the explicit TTFT-vs-ITL tradeoff knob. Selection is min-cost with reservoir-sampled tie
breaking, or softmax sampling when router_temperature > 0
(lib/kv-router/src/scheduling/selector.rs:29-85). Taints implement topology/required-zone
constraints, multiplying scores for preferred taints and filtering for required ones.
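A worked example makes the tradeoff tangible. The sketch below keeps only the device-tier credit and the decode load (host, disk, shared tiers, decay and taints are omitted), breaks ties with a uniform random pick rather than reservoir sampling, and assumes a plain softmax over negated costs; the real normalization may differ.

```python
import math
import random


def worker_logit(raw_prefill_blocks, device_overlap_blocks, decode_blocks,
                 overlap_score_credit=1.0, prefill_load_scale=1.0):
    """Cost in blocks, lower is better (simplified from the excerpt above)."""
    adjusted = max(raw_prefill_blocks - overlap_score_credit * device_overlap_blocks, 0.0)
    return prefill_load_scale * adjusted + decode_blocks


def select_worker(logits, temperature=0.0, rng=random):
    """Min-cost pick at temperature 0, otherwise sample with softmax weights."""
    if temperature <= 0.0:
        best = min(logits.values())
        return rng.choice(sorted(w for w, v in logits.items() if v == best))
    low = min(logits.values())  # shift for numerical stability
    workers = sorted(logits)
    weights = [math.exp(-(logits[w] - low) / temperature) for w in workers]
    return rng.choices(workers, weights=weights, k=1)[0]


# A 10-block request against three workers.
logits = {
    "w0": worker_logit(10, device_overlap_blocks=8, decode_blocks=5),  # warm but busy
    "w1": worker_logit(10, device_overlap_blocks=0, decode_blocks=1),  # cold, nearly idle
    "w2": worker_logit(10, device_overlap_blocks=4, decode_blocks=0),  # lukewarm, idle
}
choice = select_worker(logits)
```

The warmest worker (w0) loses to w2 because its decode backlog outweighs the extra cached blocks, which is exactly the cache-affinity versus load balance the cost function encodes.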
Load signals come from the router’s own bookkeeping, not engine polling: a slot manager
(ActiveSequencesMultiWorker, lib/kv-router/src/sequences/) predicts active blocks at
route time, marks prefill complete on first output token, and frees on stream end. With
multiple router replicas, these predictions sync over NATS core
(AddRequest / MarkPrefillCompleted / Free, lib/kv-router/src/sequences/replica_sync.rs).
Admission is serialized through SchedulerQueue::admit_one so projected load and booking
can’t race (lib/kv-router/src/scheduling/CLAUDE.md documents the invariants — worth reading
as a design doc). The KV-routing wrapper that ties selection to dispatch:
// lib/llm/src/kv_router/push_router.rs:321-335 (abridged)
let mut response_stream = cancel_on_stop(
    request_context.as_ref(),
    &context_id,
    self.inner
        .direct(updated_request, instance_id)
        .instrument(tracing::info_span!(
            "kv_router.route_request",
            request_id = %context_id,
            worker_id = instance_id,
            overlap_blocks = overlap_amount,
        )),
)
.await??;
KvRouter itself (lib/llm/src/kv_router.rs:186-205) is deliberately “decide, don’t route”:
it owns the Indexer + KvScheduler and emits FindBestMatchOutcome::{Routed, Backpressure}
— backpressure carries queued-token depth so the HTTP layer can 429/503 instead of piling on.
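The router-side load bookkeeping described earlier in this section (AddRequest / MarkPrefillCompleted / Free) can be pictured with a small Python sketch. The class and method names are hypothetical, and block counts are passed in directly rather than predicted. The point is only that load is tracked from the router's own routing decisions, not polled from the engine.

```python
from collections import defaultdict

class ActiveBlockTracker:
    """Hypothetical stand-in for router-side load bookkeeping (not Dynamo's API)."""

    def __init__(self):
        self._requests = {}  # request_id -> (worker_id, prefill_blocks, decode_blocks)
        self.prefill_blocks = defaultdict(int)
        self.decode_blocks = defaultdict(int)

    def add_request(self, request_id, worker_id, prefill_blocks, decode_blocks):
        # Booked at route time, before the engine has reported anything.
        self._requests[request_id] = (worker_id, prefill_blocks, decode_blocks)
        self.prefill_blocks[worker_id] += prefill_blocks
        self.decode_blocks[worker_id] += decode_blocks

    def mark_prefill_completed(self, request_id):
        # Called on the first output token: the prefill backlog is gone.
        worker_id, prefill, decode = self._requests[request_id]
        self.prefill_blocks[worker_id] -= prefill
        self._requests[request_id] = (worker_id, 0, decode)

    def free(self, request_id):
        # Called when the response stream ends.
        worker_id, prefill, decode = self._requests.pop(request_id)
        self.prefill_blocks[worker_id] -= prefill
        self.decode_blocks[worker_id] -= decode

tracker = ActiveBlockTracker()
tracker.add_request("r1", "worker-a", prefill_blocks=6, decode_blocks=8)
tracker.mark_prefill_completed("r1")
load_after_first_token = (tracker.prefill_blocks["worker-a"], tracker.decode_blocks["worker-a"])
tracker.free("r1")
```

With several router replicas, each of these three calls would also be broadcast so every replica converges on the same view, which is the role the text assigns to the NATS sync.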
4.3 Disaggregated prefill/decode
Disaggregation is orchestrated entirely at the routing layer by PrefillRouter
(lib/llm/src/kv_router/prefill_router/mod.rs:36-43), an operator sandwiched between
migration and the decode router (see the pipeline chain above). The split:
- Clone the preprocessed request, force max_tokens = 1, route it to a prefill worker (KV-aware, using the same selector with prefill-specific config):
// lib/llm/src/kv_router/prefill_router/mod.rs:114-116
// Prepare prefill request with max_tokens = 1 (clone after tracker is set)
let mut prefill_req = req.clone();
prefill_req.stop_conditions.max_tokens = Some(1);
- The prefill worker computes KV and returns transfer metadata, not KV bytes. In the vLLM handler the metadata is vLLM’s kv_transfer_params (block IDs + connection info):
# components/src/dynamo/vllm/handlers.py:3641-3651 (abridged)
output: Dict[str, Any] = {
    "token_ids": list(token_ids),
    "disaggregated_params": self._build_disaggregated_params(
        kv_protocol.decode_request_kv_transfer_params(res),
        embedding_params,
    ),
    "completion_usage": ...,
}
- PrefillRouter injects that metadata into the decode request, restores the original max_tokens, and forwards to the decode router:
// lib/llm/src/kv_router/prefill_router/mod.rs:316-335 (abridged)
match outcome {
    PrefillOutcome::Bootstrap { bootstrap_info, worker_id } => {
        decode_req.bootstrap_info = Some(bootstrap_info);
        decode_req.routing_mut().prefill_worker_id = Some(worker_id);
    }
    PrefillOutcome::Completed { result, worker_id, worker_link } => {
        decode_req.prefill_result = Some(result);
        decode_req.migration_link = worker_link;
        ...
    }
};
- The decode worker uses the metadata to pull KV directly from the prefill worker’s VRAM via NIXL (NVLink / InfiniBand-UCX / PCIe), non-blocking with respect to its forward passes. The transfer itself runs inside the engines’ KV-connector layer; Dynamo’s own Rust NIXL agent (lib/memory/src/nixl/agent.rs) is used by KVBM and multimodal RDMA.
Two execution shapes (docs/design-docs/disagg-serving.md): SGLang-style bootstrap —
prefill worker publishes an RDMA rendezvous endpoint, so prefill is spawned as a background
task and decode routing proceeds immediately, overlapping transfer with decode scheduling;
vLLM/TRT-LLM-style synchronous — decode waits for the prefill response. Decode-side
routing then runs with an override that zeroes overlap credit and prompt-load tracking
(build_decode_router_override, asserted at prefill_router/mod.rs:453-463), because the
decode pool’s cache state was just mutated by the transfer, not by prefix reuse.
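The synchronous shape of that split fits in a few lines of Python. This is a sketch under stated assumptions, not Dynamo's implementation: requests are plain dicts, run_disaggregated and the two fake worker callables are hypothetical, and the metadata is a made-up block-ID list standing in for kv_transfer_params.

```python
import copy

def run_disaggregated(request, prefill_worker, decode_worker):
    """Hypothetical two-phase flow; the worker callables stand in for real engines."""
    # Phase 1: clone the request and cap it at one token so the worker only prefills.
    prefill_req = copy.deepcopy(request)
    prefill_req["max_tokens"] = 1
    prefill_result = prefill_worker(prefill_req)

    # Phase 2: the decode request keeps the original max_tokens and carries
    # transfer metadata (not KV bytes) so the decode worker can pull the cache.
    decode_req = copy.deepcopy(request)
    decode_req["disaggregated_params"] = prefill_result["disaggregated_params"]
    return decode_worker(decode_req)

def fake_prefill(req):
    # A real prefill worker would compute KV and return connection info.
    return {"token_ids": [101], "disaggregated_params": {"block_ids": [0, 1, 2]},
            "seen_max_tokens": req["max_tokens"]}

def fake_decode(req):
    # A real decode worker would pull KV using the metadata, then generate.
    return {"max_tokens": req["max_tokens"], "kv_source": req["disaggregated_params"]}

request = {"token_ids": list(range(40)), "max_tokens": 64}
result = run_disaggregated(request, fake_prefill, fake_decode)
```

Working on copies keeps the caller's request untouched, so the original max_tokens never needs an explicit restore step in this sketch.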
Operationally notable details for someone who has run this class of system:
- Prefill is deliberately not linked as a child of the request context: cancelling it mid-NIXL-transfer would permanently leak KV blocks, so it runs to completion and wasted compute is the accepted tradeoff (prefill_router/mod.rs:153-162; a comment there cites the bug).
- Backpressure from the prefill queue surfaces as ResourceExhausted rather than silently re-entering the saturated queue (prefill_router/mod.rs:181-203).
- Fallback is explicit: with no prefill workers the router falls back to aggregated mode, unless --enforce-disagg is set, in which case it fails fast; all-prefill-death flips a deactivated flag checked per request (prefill_router/mod.rs:96-104).
- Topology constraints (“decode must be NVLink-reachable from this prefill worker”) are merged into the decode request as required/preferred taints (merge_decode_topology_constraints, prefill_router/mod.rs:424-442).
4.4 The distributed runtime
Everything above runs on dynamo-runtime. The object model is a three-level hierarchy —
Namespace / Component / Endpoint (lib/runtime/src/component.rs:4-30) — addressed as
dyn://namespace.component.endpoint. An Instance is one live process serving an endpoint,
identified by a u64 instance_id and a transport address
(lib/runtime/src/component.rs:106-115). The hello-world server is the whole pattern:
// lib/runtime/examples/hello_world/src/bin/server.rs:54-65
async fn backend(runtime: DistributedRuntime) -> anyhow::Result<()> {
    // attach an ingress to an engine
    let ingress = Ingress::for_engine(RequestHandler::new())?;
    let component = runtime.namespace(DEFAULT_NAMESPACE)?.component("backend")?;
    component
        .endpoint("generate")
        .endpoint_builder()
        .handler(ingress)
        .start()
        .await
}
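The three-level addressing above is simple enough to model directly. The following Python sketch assumes that none of the three names contains a dot; the EndpointAddress type, the parse_dyn_address helper, and the example namespace name are all hypothetical.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class EndpointAddress:
    namespace: str
    component: str
    endpoint: str

def parse_dyn_address(address):
    """Split 'dyn://namespace.component.endpoint' into its three levels."""
    scheme = "dyn://"
    if not address.startswith(scheme):
        raise ValueError(f"expected a dyn:// address, got {address!r}")
    parts = address[len(scheme):].split(".")
    # Exactly three non-empty segments are required.
    if len(parts) != 3 or not all(parts):
        raise ValueError(f"expected namespace.component.endpoint, got {address!r}")
    return EndpointAddress(*parts)

addr = parse_dyn_address("dyn://example.backend.generate")
```

An instance would add a fourth coordinate (the u64 instance_id) on top of this address, since many processes can serve the same endpoint.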
Discovery is pluggable (lib/runtime/src/distributed.rs:148-184): a Discovery trait
with two families — Kubernetes-native (CRD DynamoWorkerMetadata + EndpointSlices, no etcd)
and KV-store backed (etcd, a plain filesystem directory for laptop dev via
--discovery-backend file, or in-memory for tests). Workers write instance records and
model cards; frontends list_and_watch a DiscoveryQuery and react to Added/Removed.
Liveness in the etcd backend is lease-based — every registration is attached to a TTL lease kept alive by a background task; if keep-alive fails the process’s cancellation token fires, taking the whole worker down rather than serving as a zombie:
// lib/runtime/src/transports/etcd/lease.rs:15-27
pub async fn create_lease(
    connector: Arc<Connector>,
    ttl: u64,
    token: CancellationToken,
) -> anyhow::Result<u64> {
    let mut lease_client = connector.get_client().lease_client();
    let lease = lease_client.grant(ttl as i64, None).await?;
    let id = lease.id() as u64;
    let ttl = lease.ttl() as u64;
    let child = token.child_token();
    // ... (excerpt ends at line 27; the rest of the function is not shown)
}
Lease IDs double as instance IDs, which is why the TCP layer keeps short-lived tombstones
keyed on them during re-registration races (lib/runtime/src/pipeline/network/tcp/server.rs:15-19).
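The lease pattern described above (renew in the background, take the process down if renewal fails) looks roughly like this in Python. It is an illustrative sketch, not the Rust implementation: send_keep_alive is a hypothetical async callable, an asyncio.Event stands in for the cancellation token, and renewing every third of the TTL is an assumed interval.

```python
import asyncio

async def keep_lease_alive(send_keep_alive, ttl_seconds, cancelled):
    """Renew a lease until renewal fails, then signal process-wide shutdown."""
    while not cancelled.is_set():
        try:
            ok = await send_keep_alive()
        except Exception:
            ok = False
        if not ok:
            # Without a lease, discovery no longer lists this worker: stop serving.
            cancelled.set()
            return
        # Renew well inside the TTL so one slow round trip does not expire it.
        await asyncio.sleep(ttl_seconds / 3)

async def demo():
    cancelled = asyncio.Event()
    attempts = []

    async def flaky_keep_alive():
        attempts.append(1)
        return len(attempts) < 3  # the third renewal fails

    await keep_lease_alive(flaky_keep_alive, ttl_seconds=0.03, cancelled=cancelled)
    return cancelled.is_set(), len(attempts)

lease_lost, renewals = asyncio.run(demo())
```

Failing closed is the design choice worth noticing: a worker that cannot prove liveness stops itself, so routers never keep sending traffic to an instance that discovery has already forgotten.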
NATS is optional everywhere: requests default to TCP, discovery defaults to K8s/file,
and KV events can ride ZMQ. NATS earns its place for JetStream-durable KV events, router
replica sync, and brokered request plane (README.md:246-262).
4.5 Planner / autoscaling
components/src/dynamo/planner is the SLA autoscaler (Python). Signals are forward-pass
metrics (FPM) scraped from workers — queued prefill tokens, KV-cache utilization, observed
TTFT/ITL — fed into a state machine with two modes: simple threshold scaling and SLA mode
backed by per-GPU performance models (interpolated from offline profiling or AIConfigurator
predictions). The non-SLA thresholds are legible enough to read directly:
# components/src/dynamo/planner/core/load_scaling.py:24-34
# Prefill: ratio of queued_prefill_tokens / context_length
_PREFILL_THROUGHPUT_SCALE_UP = 1.0 # queued >= context_length
_PREFILL_THROUGHPUT_SCALE_DOWN = 0.1 # queued < context_length / 10
_PREFILL_LATENCY_SCALE_UP = 0.1 # queued >= context_length / 10
_PREFILL_LATENCY_SCALE_DOWN = 0.0 # queued == 0
# Decode/Agg: KV cache utilization (scheduled + queued) / max_kv_tokens
_DECODE_THROUGHPUT_SCALE_UP = 1.0 # util > 100%
_DECODE_THROUGHPUT_SCALE_DOWN = 0.6 # util < 60%
_DECODE_LATENCY_SCALE_UP = 0.4 # util > 40%
_DECODE_LATENCY_SCALE_DOWN = 0.1 # util < 10%
Prefill and decode pools scale independently (the economic argument for disaggregation),
with budget caps and connectors that execute decisions against Kubernetes (DGD resources via
the operator) or virtual targets for simulation. The DynamoGraphDeploymentRequest CRD
(README.md:162-175) chains AIConfigurator profiling → planner topology → deployment from a
single model + SLA spec. A global_planner component coordinates across deployments.
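The prefill thresholds quoted above translate into a three-way decision. The constants below are copied from the excerpt; the prefill_scaling_decision function, its mode argument, and the single-worker view are assumptions for illustration, and the planner's actual aggregation across workers is not shown in the source.

```python
# Thresholds copied from the load_scaling.py excerpt above.
_PREFILL_THROUGHPUT_SCALE_UP = 1.0
_PREFILL_THROUGHPUT_SCALE_DOWN = 0.1
_PREFILL_LATENCY_SCALE_UP = 0.1
_PREFILL_LATENCY_SCALE_DOWN = 0.0

def prefill_scaling_decision(queued_prefill_tokens, context_length, mode="throughput"):
    """Return 'up', 'down', or 'hold' for the prefill pool (hypothetical helper)."""
    ratio = queued_prefill_tokens / context_length
    if mode == "throughput":
        if ratio >= _PREFILL_THROUGHPUT_SCALE_UP:
            return "up"      # queued >= context_length
        if ratio < _PREFILL_THROUGHPUT_SCALE_DOWN:
            return "down"    # queued < context_length / 10
        return "hold"
    if mode == "latency":
        if ratio >= _PREFILL_LATENCY_SCALE_UP:
            return "up"      # queued >= context_length / 10
        if ratio <= _PREFILL_LATENCY_SCALE_DOWN:
            return "down"    # queued == 0
        return "hold"
    raise ValueError(f"unknown mode: {mode!r}")

decisions = [
    prefill_scaling_decision(8192, 8192),
    prefill_scaling_decision(2000, 8192),
    prefill_scaling_decision(400, 8192),
    prefill_scaling_decision(1000, 8192, mode="latency"),
    prefill_scaling_decision(0, 8192, mode="latency"),
]
```

The gap between the up and down thresholds is the hysteresis band: inside it the pool holds steady, which keeps the autoscaler from flapping on small queue fluctuations.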
5. Suggested reading path
- Orientation (30 min). README.md, then docs/design-docs/architecture.md and docs/design-docs/dynamo-flow.md (the numbered S1-S9 disagg walkthrough).
- Runtime primitives. lib/runtime/examples/hello_world/src/bin/server.rs and client.rs (74 lines total, covering the whole component/endpoint/push-router model), then lib/runtime/src/component.rs doc comments and lib/runtime/src/distributed.rs:117-184 for discovery backends. Skim lib/runtime/src/pipeline/network/tcp/server.rs for the call-home response stream.
- The frontend pipeline. lib/llm/src/entrypoint/input/common.rs:323-365 (operator chain), lib/llm/src/http/service/openai.rs (SSE handler), and lib/llm/src/preprocessor.rs (what “preprocess” actually does).
- The router, bottom-up. lib/kv-router/src/protocols.rs (hashes, events, OverlapScores), lib/kv-router/src/indexer/radix_tree.rs (single-threaded tree first; the concurrent versions are optimizations of the same shape), lib/kv-router/src/scheduling/selector.rs (cost function; tests at the bottom are executable documentation), lib/kv-router/src/scheduling/CLAUDE.md (invariants), then up to lib/llm/src/kv_router.rs and lib/llm/src/kv_router/push_router.rs for the integration. Read docs/design-docs/router-design.md alongside.
- Disaggregation. docs/design-docs/disagg-serving.md, then lib/llm/src/kv_router/prefill_router/mod.rs top to bottom (the generate method is the whole story), with components/src/dynamo/vllm/handlers.py prefill/decode handlers as the engine-side counterpart.
- Worker lifecycle in Python. components/src/dynamo/vllm/main.py + worker_factory.py, and lib/bindings/python/rust/lib.rs to see how thin the binding is.
- Optional depth. KVBM (lib/llm/src/block_manager.md, lib/kvbm-logical, lib/memory/src/nixl/), planner (components/src/dynamo/planner/core/), the Envoy EPP (deploy/inference-gateway/ext-proc/src/picker.rs, epp.rs).
6. Connections to your other study repos
- vllm / sglang: the engines Dynamo orchestrates. What you studied as their internals (paged KV, prefix caching, continuous batching) becomes Dynamo’s signal surface: vLLM’s ZMQ KV-event publisher and kv_transfer_params, SGLang’s RDMA bootstrap, both wrapped by components/src/dynamo/{vllm,sglang}. Dynamo’s mocker reimplements their schedulers in Rust (lib/mocker), a compact second take on what you read in those codebases.
- sgl-router (in the sglang repo): the same cache-aware-routing idea in one tier and one binary, with an approximate radix tree built from the requests it routed and no event plane. Dynamo’s router is the heavyweight sibling: event-fed exact index, multi-replica state sync, tiered cache credits, prefill/decode phase awareness, admission queue. Reading sgl-router’s tree side by side with lib/kv-router/src/indexer/ is a good compare-and-contrast exercise.
- llm-d: the K8s-native competitor (Google/Red Hat/IBM orbit). Same thesis (disagg + cache-aware scheduling above the model server) but architected as Kubernetes: Go, Inference Gateway EPP as the router, vLLM-first. Dynamo is engine-agnostic, Rust, and owns its own runtime, with K8s as one discovery backend among several; contrast llm-d’s gateway-resident scorer with Dynamo’s in-frontend KvPushRouter plus optional EPP.
- gateway-api-inference-extension: the standard llm-d builds on, and Dynamo meets it too. deploy/inference-gateway/ext-proc is a Rust ext_proc Endpoint Picker that mirrors the Go LW-EPP interface (picker.rs:5-8 cites GAIE #2834) but scores with dynamo-kv-router. In gateway mode the frontend runs --router-mode direct and honors the EPP’s x-prefill-instance-id/worker headers (README.md:111-122), which is directly relevant to your Envoy work.
- nano-vllm: the minimal engine; useful here as the mental model for what a Dynamo worker does between receiving PreprocessedRequest and yielding token deltas.
- xgrammar: structured-output constraints; surfaces in Dynamo only at the edges (lib/parsers parses tool calls; guided decoding is delegated to engines via nvext).
- flashinfer: two layers below, at the engine kernels. Dynamo never touches it, but KVBM’s layout/transfer code (lib/kvbm-kernels, lib/memory) is where Dynamo gets closest to that altitude.
7. Tinkering on your machine
Windows host: plan on WSL2. Published wheels are manylinux-only — pip on Windows will
not resolve ai-dynamo (docs/reference/release-artifacts.md), and the from-source path
assumes Ubuntu packages (libhwloc-dev, protobuf-compiler; README.md:191-204). Your
RTX 5080 (16 GB, Blackwell) works under WSL2 Ubuntu 24.04 with a CUDA 13 driver; small
models (Qwen3-0.6B, Llama-3.2-1B/3B) leave room to spare. NIXL GPU-to-GPU disagg across two
real workers is not happening on one consumer GPU — use the mocker for that topology.
Tier 0 — no GPU, no Linux strictly required (read/test the Rust core). The routing and protocol crates are pure Rust with no CUDA or engine dependency:
cargo test -p dynamo-kv-router # radix tree, selector, queue, sequences
cargo test -p dynamo-tokens -p dynamo-kv-hashing -p dynamo-protocols
cargo test -p dynamo-mocker # simulated vLLM/SGLang schedulers
cargo bench -p dynamo-llm --bench kv_router_bench --features kv-router-stress
The selector tests (lib/kv-router/src/scheduling/selector.rs:465-1383) are the fastest way
to internalize the cost function — modify a weight, predict the winner, run. Inside WSL2 the
full workspace including dynamo-runtime and dynamo-llm builds and tests CPU-only.
Tier 1 — full serving stack, zero GPUs (mocker). This exercises the real frontend,
preprocessor, KV router, and even disaggregated flow against a simulated engine
(docs/dynosim/mocker.md):
python -m dynamo.frontend --http-port 8000 --router-mode kv &
python -m dynamo.mocker --model-path Qwen/Qwen3-0.6B --num-workers 4 --speedup-ratio 10
# disagg without GPUs:
python -m dynamo.mocker --model-path Qwen/Qwen3-0.6B --disaggregation-mode prefill --bootstrap-ports 50100 &
python -m dynamo.mocker --model-path Qwen/Qwen3-0.6B --disaggregation-mode decode
Watch router decisions with RUST_LOG=dynamo_llm::kv_router=debug — the selector logs the
full cost decomposition per worker (selector.rs:219-229). This is the highest
learning-per-watt setup in the repo for someone who cares about routing.
Tier 2 — one real GPU (WSL2). The README quick start, no etcd/NATS needed:
uv pip install --prerelease=allow "ai-dynamo[vllm]"
python -m dynamo.frontend --http-port 8000 --discovery-backend file &
python -m dynamo.vllm --model Qwen/Qwen3-0.6B --discovery-backend file \
--kv-events-config '{"enable_kv_cache_events": false}'
curl -s localhost:8000/v1/chat/completions -H 'Content-Type: application/json' -d \
'{"model":"Qwen/Qwen3-0.6B","messages":[{"role":"user","content":"hi"}],"max_tokens":50,"stream":true}'
With ~16 GB you can also run two vLLM workers on tiny models (cap --gpu-memory-utilization
around 0.4 each) and turn on --router-mode kv to watch cache-aware routing pick the warm
worker on repeated prefixes — the single most instructive experiment in the repo.
Tier 3 — runtime-only Rust, near-zero deps. lib/runtime/examples/hello_world runs two
processes against file/memory discovery; lib/llm has in-process engine examples
(echo_full) reachable via dynamo-run-style entrypoints (lib/llm/src/entrypoint/), and
cargo run -p dynamo-llm --bin generate-frontend-openapi exercises the HTTP surface without
any worker at all.
Things that genuinely need a cluster (read, don’t run): NIXL GPU-to-GPU transfer, KVBM multi-tier offload under pressure, the K8s operator/DGDR path, planner scaling against real FPM streams, JetStream-durable KV events with multi-replica routers.
llm-d
A study guide to the llm-d/llm-d repository (checked out at materials/llm-d/llm-d, tracking main around the v0.7 release, 2026-05).
1. What it is
llm-d is a Kubernetes-native distributed inference serving stack: an orchestration and routing layer that sits above model servers (vLLM, SGLang) and below your clients, built out of the Kubernetes Gateway API, Envoy ext-proc, and engine-level telemetry. From README.md:
llm-d is a high-performance distributed inference serving stack optimized for production deployments on Kubernetes. We help you achieve the fastest “time to state-of-the-art (SOTA) performance” for key OSS large language models across most hardware accelerators…
llm-d is a Cloud Native Computing Foundation (CNCF) sandbox project, founded by Red Hat, Google Cloud, IBM Research, CoreWeave, and NVIDIA.
Positioning, in the project’s own words (README.md): “Model servers like vLLM and SGLang handle efficiently running large language models on accelerators. llm-d provides state-of-the-art orchestration and optimizations above model servers to serve high-scale real-world traffic efficiently and reliably.”
Two principles from PROJECT.md explain its shape better than any diagram:
- We respect our upstreams - vLLM and the Kubernetes Inference Gateway are where code changes start, no forks
- vLLM-first but not vLLM-only - build the modular architecture for most people and collaborate with other projects
So llm-d is not a monolith — it is a composition of the Gateway API Inference Extension (GAIE), vLLM, and a set of llm-d-org components, wired together by the recipes in this repo. Project maintainers (per PROJECT.md) are Carlos Costa, Clayton Coleman, and Robert Shaw — representing inference research, the llm-d Router, and vLLM respectively.
2. Why you care
This project is the intersection of what you already know and what you’re learning:
- The data plane is Envoy. The “llm-d Router” is literally Envoy (or any conformant L7 proxy: Istio, agentgateway, GKE ALB) plus an ext-proc gRPC service called the Endpoint Picker (EPP). The repo ships a complete, readable static Envoy config (guides/no-kubernetes-deployment/router/envoy/envoy.yaml): ext_proc filter in FULL_DUPLEX_STREAMED mode, an ORIGINAL_DST cluster keyed off a routing header, circuit breakers, gRPC health checks on the EPP. It will look like home.
- The control plane is Gateway API. Gateway + HTTPRoute + a new backend kind, InferencePool, from the GAIE project you’re studying separately. llm-d is GAIE’s most prominent downstream consumer and contributor.
- The novel part is the scoring signal. Instead of least-request/ring-hash on connection counts, endpoints are scored on KV-cache utilization, prefix-cache locality (approximate radix-style hashing or precise event-driven indexing), queue depth, and predicted latency; that is, load balancing where the “load” is HBM contents. This is cache-aware routing and P/D disaggregation, productionized.
- The benchmark argument is a traffic argument. guides/optimized-baseline/README.md compares the EPP against a stock Kubernetes Service round-robining the same 8 vLLM pods: output tokens/sec 5,722 → 13,163 (+130%), TTFT p90 107.43s → 0.206s at high rates. The headline gains of this whole space come from routing, not kernels.
3. What’s actually in this repo (and what isn’t)
This is the umbrella/docs/deployment repo of the llm-d GitHub org. There is no service source code here — no Go, no Rust, no Python services. What you deploy comes from sibling repos and OCI registries; what lives here is the documentation, the Helm values + Kustomize overlays that configure those components, and the Dockerfiles for llm-d’s custom vLLM container images.
| Path | What it is |
|---|---|
| README.md, PROJECT.md, SIGS.md, CONTRIBUTING.md, MAINTAINERS.md | Project identity, governance, SIG structure (Kubernetes-style OWNERS files throughout) |
| docs/architecture/ | The system design docs, the core of this guide. core/ (router, EPP, InferencePool, model servers) and advanced/ (KV management, disaggregation, autoscaling, batch, latency predictor) |
| docs/well-lit-paths/ | Concept pages for each supported pattern (the “why”), one per pattern |
| docs/api-reference/ | InferencePool, InferenceObjective, InferenceModelRewrite CRDs; EndpointPickerConfig schema; EPP HTTP headers/APIs; glossary |
| docs/getting-started/ | quickstart.md and artifacts.md (the authoritative map of charts, images, and source repos) |
| docs/infra-providers/ | GKE, AKS, OpenShift, DigitalOcean, minikube notes |
| docs/resources/observability/ | Metrics catalog, PromQL cookbook, tracing setup; plus docs/resources/rdma/ for the network prerequisites |
| docs/proposals/ | Design proposals (autoscaler, batch gateway, non-Kubernetes mode, distributed tracing, …) |
| guides/ | The deployable recipes (“well-lit paths”, the “how”): per-guide Helm values files for the router chart + Kustomize overlays for model servers, per accelerator (NVIDIA/AMD/Intel XPU/Gaudi/TPU/CPU) |
| guides/recipes/ | Shared building blocks: base router values, base model-server Deployments, gateway install kustomizations (Istio, agentgateway, kgateway, GKE) |
| docker/ | Dockerfile.cuda, .rocm, .cpu, .hpu, which build the ghcr.io/llm-d/llm-d-cuda etc. images: vLLM plus the RDMA/P2P stack (UCX, NVSHMEM, NIXL, GDRCopy, DeepEP, LMCache, InfiniStore) |
| patches/ | NVSHMEM patches applied in those image builds |
| helpers/, scripts/, release/, .github/ | Benchmark harness docs, client setup, lint/CI plumbing |
The actual code, by sibling repo (from the table in docs/getting-started/artifacts.md, names only):
- llm-d-router (Go): the EPP, comprising the routing engine, plugin framework, and flow control. Older docs call it llm-d-inference-scheduler; it builds on the GAIE endpoint-picker framework. Most architecture docs here link directly into its pkg/epp/framework/plugins/... tree.
- llm-d-routing-sidecar: the P/D routing proxy sidecar in decode pods (image ghcr.io/llm-d/llm-d-routing-sidecar).
- llm-d-kv-cache (Go/Python/C++): KV-block locality indexer and filesystem offloading connector.
- llm-d-latency-predictor (Python): XGBoost training + prediction sidecars for predicted-latency scheduling.
- llm-d-workload-variant-autoscaler (Go): SLO-aware autoscaler.
- llm-d-batch-gateway, llm-d-async (incubation): OpenAI Batch API and queue-based async processing.
- llm-d-benchmark (Python): the harness invoked by every guide’s benchmarking section.
- llm-d-inference-sim (Go): a GPU-free vLLM simulator (important for you; see section 7).
One subtlety worth knowing (from docs/getting-started/artifacts.md): the Helm charts themselves are currently published by GAIE (oci://registry.k8s.io/gateway-api-inference-extension/charts/standalone in the quickstart; oci://ghcr.io/llm-d/charts/llm-d-router-standalone-dev in the v0.7 guides), with a note that publishing will move to llm-d. The boundary between GAIE and llm-d is deliberately thin.
4. The architecture as documented
4.1 Three core concepts
From docs/architecture/README.md: the llm-d Router (= Proxy + EPP), the InferencePool, and the Model Server.
- Proxy: A high-performance L7 proxy (typically Envoy) that accepts user requests and consults the EPP via the ext-proc protocol to determine the optimal destination.
- Endpoint Picker (EPP): The routing engine that scores and selects model server pods based on real-time metrics, KV-cache affinity, and configured policies.
InferencePool is described as an “LLM-optimized Service” — a label-selector grouping of pods serving one base model, with Variants (sub-groupings via pod labels, e.g. prefill vs decode roles, expressed as llm-d.ai/role: prefill|decode|prefill-decode).
Terminology is pinned down in docs/architecture/core/router/README.md and worth internalizing because older blog posts disagree: llm-d Router = Proxy + EPP (the whole entry point); Inference Gateway = the Router when operating in Gateway Mode; Request Scheduler = the decision engine inside the EPP. The EPP also carries dual responsibilities — routing and “fairness and prioritization”, i.e. which requests run at all when consolidating multi-tenant workloads onto shared model servers.
4.2 Request flow: gateway → EPP → vLLM pod
From docs/architecture/core/router/README.md:
When an inference request arrives at the Proxy, the Proxy “parks” the request and initiates a callback to the EPP via the ext-proc (External Processing) protocol. The EPP evaluates the request against the current state of the InferencePool—considering factors like KV-cache locality, current load, and priority—and returns the address of the optimal model server pod back to the Proxy.
Inside the EPP, the lifecycle is enumerated in docs/architecture/core/router/epp/README.md:
- Request arrival at the proxy (Gateway).
- External processing: the proxy invokes the EPP via ext-proc, passing headers and body.
- Request handling: parses the request (OpenAI HTTP, vLLM gRPC; parser is a plugin, so custom protocols slot in) into the internal InferenceRequest.
- Flow control: if enabled, queues, prioritizes, and “holds requests when the pool is saturated”.
- Request scheduling: Filter → Score → Pick against the InferencePool.
- Request proxying: the EPP returns the chosen endpoint address; the proxy forwards.
Asynchronously, a Data Layer watches the Kube API for pool membership, scrapes model-server metrics, and maintains in-memory state such as the prefix-cache tree; “consultant” sidecars (latency predictor, KV indexer, tokenizer) plug in here. The same doc warns that the only supported ext-proc body mode is FULL_DUPLEX_STREAMED.
Flow control deserves a closer look from a traffic engineer (docs/architecture/core/router/epp/README.md, deep dive in flow-control.md): saturation-gated admission via pluggable SaturationDetectors (e.g. a concurrency detector on per-endpoint in-flight counts), priority bands separating latency-sensitive chat from background batch, and two pluggable fairness layers — FairnessPolicy distributing dispatch opportunities among flows within a band (e.g. round robin) and OrderingPolicy ordering requests within a flow (FIFO, SLO-based). It is an application-level admission/queuing tier of the kind you’d otherwise build with Envoy adaptive concurrency plus priority queues — but keyed on inference signals.
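The three layers named above (priority bands, fairness between flows, ordering within a flow) compose as in the following Python sketch. It is a toy model, not the EPP's code: FlowQueue is a hypothetical class, a larger number is assumed to mean higher priority, the saturation signal is passed in as a boolean, and round robin and FIFO are hard-coded where the real system has pluggable policies.

```python
from collections import OrderedDict, deque

class FlowQueue:
    """Hypothetical admission queue: strict priority bands, round-robin fairness
    between flows inside a band, FIFO ordering inside each flow."""

    def __init__(self):
        self._bands = {}  # priority -> OrderedDict(flow_id -> deque of requests)

    def enqueue(self, priority, flow_id, request):
        flows = self._bands.setdefault(priority, OrderedDict())
        flows.setdefault(flow_id, deque()).append(request)

    def dispatch(self, saturated):
        """Return the next request, or None if the pool is saturated or idle."""
        if saturated:
            return None  # hold everything while the pool is saturated
        for priority in sorted(self._bands, reverse=True):  # highest band first
            flows = self._bands[priority]
            if not flows:
                continue
            flow_id, queue = next(iter(flows.items()))
            request = queue.popleft()      # FIFO within the flow
            del flows[flow_id]
            if queue:
                flows[flow_id] = queue     # rotate the flow to the back
            return request
        return None

q = FlowQueue()
q.enqueue(10, "tenant-a", "a1")
q.enqueue(10, "tenant-a", "a2")
q.enqueue(10, "tenant-b", "b1")
q.enqueue(0, "batch", "x1")
held = q.dispatch(saturated=True)
order = [q.dispatch(saturated=False) for _ in range(4)]
```

Tenant A's second request waits behind tenant B's first even though it arrived earlier, and the batch request only runs once the higher band is empty.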
The scheduler (docs/architecture/core/router/epp/scheduling.md) is a weighted-scorer framework — recognizably the GAIE scheduling framework with llm-d’s plugin set. Scorers include kv-cache-utilization-scorer, queue-depth-scorer, prefix-scorer, lora-affinity-scorer, latency-scorer, session-affinity-scorer, and no-hit-lru-scorer (spreads cold prefills across the pool); pickers are max-score-picker, random-picker, weighted-random-picker. Where KV-cache-aware decisions happen: scoring, fed by the data layer.
The default “optimized baseline” policy is just YAML, shipped in this repo as Helm values (guides/optimized-baseline/router/optimized-baseline.values.yaml):
pluginsCustomConfig:
  optimized-baseline-plugins.yaml: |
    apiVersion: llm-d.ai/v1alpha1
    kind: EndpointPickerConfig
    plugins:
      - type: queue-scorer
      - type: kv-cache-utilization-scorer
      - type: prefix-cache-scorer
      - type: no-hit-lru-scorer
    schedulingProfiles:
      - name: default
        plugins:
          - pluginRef: queue-scorer
            weight: 2
          - pluginRef: kv-cache-utilization-scorer
            weight: 2
          - pluginRef: prefix-cache-scorer
            weight: 3
          - pluginRef: no-hit-lru-scorer
            weight: 2
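What those weights do is easiest to see as arithmetic. The Python sketch below combines three of the four scorers with the weights from the config above and picks the maximum. The scorer formulas are assumptions for illustration (the source only says scorers exist and are weighted), no-hit-lru-scorer is left out, and weighted_pick is a hypothetical helper.

```python
def weighted_pick(endpoints, scorers, weights):
    """Max-score picker over weighted scorer outputs in [0, 1] (sketch)."""
    totals = {}
    for name, state in endpoints.items():
        totals[name] = sum(weights[s] * fn(state) for s, fn in scorers.items())
    best = max(totals, key=totals.get)
    return best, totals

# Hypothetical normalizations: higher is better, each in [0, 1].
scorers = {
    "queue-scorer": lambda s: 1.0 / (1 + s["queue_depth"]),
    "kv-cache-utilization-scorer": lambda s: 1.0 - s["kv_util"],
    "prefix-cache-scorer": lambda s: s["prefix_hit_ratio"],
}
# Weights taken from the optimized-baseline values file.
weights = {"queue-scorer": 2, "kv-cache-utilization-scorer": 2, "prefix-cache-scorer": 3}

endpoints = {
    "pod-a": {"queue_depth": 3, "kv_util": 0.5, "prefix_hit_ratio": 0.75},
    "pod-b": {"queue_depth": 0, "kv_util": 0.25, "prefix_hit_ratio": 0.0},
}
best, totals = weighted_pick(endpoints, scorers, weights)
```

Here the busier pod still wins because its prefix hit outweighs its queue. Shrink the prefix weight from 3 to 2 and the idle pod wins instead, which is the tuning lever the YAML exposes.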
4.3 The Envoy wiring (read this file first)
guides/no-kubernetes-deployment/router/envoy/envoy.yaml is the entire data plane in one static config. The selected endpoint is conveyed via a header consumed by an ORIGINAL_DST cluster:
clusters:
  - name: original_destination_cluster
    type: ORIGINAL_DST
    connect_timeout: 1000s
    lb_policy: CLUSTER_PROVIDED
    original_dst_lb_config:
      use_http_header: true
      http_header_name: x-gateway-destination-endpoint
and the ext_proc filter that calls the EPP:
- name: envoy.filters.http.ext_proc
  typed_config:
    "@type": type.googleapis.com/envoy.extensions.filters.http.ext_proc.v3.ExternalProcessor
    grpc_service:
      envoy_grpc:
        cluster_name: ext_proc
        authority: localhost:9002
      timeout: 10s
    processing_mode:
      request_body_mode: FULL_DUPLEX_STREAMED
      response_body_mode: FULL_DUPLEX_STREAMED
That is the GAIE Endpoint Picker Protocol made concrete: EPP picks a pod IP, writes it into x-gateway-destination-endpoint, Envoy’s original-dst LB forwards to it. Response bodies also stream back through the EPP for post-processing (metrics, prefix-index updates).
4.4 Two deployment modes for the proxy
docs/architecture/core/router/proxy.md defines:
- Standalone Mode: Envoy runs as a sidecar in the EPP pod, ext-proc over localhost, no Gateway/HTTPRoute/controller needed. For testing, batch, RL pipelines, legacy-Ingress clusters.
- Gateway Mode (= “Inference Gateway”): the GAIE integration, in which a shared Gateway hosts HTTPRoutes whose backendRefs are InferencePools (group: inference.networking.k8s.io), alongside ordinary Service backends. Used for shared infra, multi-cluster LB, traffic splitting/canary. Supported providers documented: Istio, GKE Gateway, agentgateway (kgateway deprecated as of v0.7).
4.5 The well-lit paths
Indexed in docs/well-lit-paths/README.md and guides/README.md, grouped as: Intelligent Routing (optimized baseline; predicted-latency routing), Advanced KV-Cache Management (precise prefix-cache routing; tiered prefix cache offload to CPU/NVMe), Serving Large Models (P/D disaggregation; wide expert-parallelism), Operational Excellence (flow control; workload autoscaling; rollouts), Workloads (agentic inference; multimodal), and Experimental (async processing; batch gateway; no-Kubernetes deployment).
Prefix-cache-aware routing (docs/architecture/advanced/kv-management/prefix-cache-aware-routing.md) has two implementations:
| Feature | Approximate | Precise |
|---|---|---|
| Precision | Heuristic (character-based block hashing) | 100% (token-based) |
| State source | Local EPP assumptions after each routing decision | Real-time KVEvents from model servers over ZMQ |
| Dependencies | None | vLLM /v1/completions/render tokenizer endpoint, ZMQ |
The approximate path splits prompts into fixed-size blocks, keeps a rolling-hash LRU index of which prefixes were routed where, and “learns” from its own decisions. The precise path subscribes to vLLM KV-cache block add/evict events and maintains a global KV-block index (the llm-d-kv-cache component), with speculative indexing to cover the decision-to-event blind spot. guides/precise-prefix-cache-routing/router/precise-prefix-cache-routing.values.yaml shows the full production config including active-active HA EPP replicas.
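A minimal Python sketch of the approximate path follows. It swaps in chained SHA-256 over fixed-size character blocks where the text says rolling hash, and the block size, the ApproxPrefixIndex class, and its method names are all hypothetical. What it preserves is the idea: the index is built only from the router's own past decisions.

```python
import hashlib
from collections import OrderedDict

BLOCK_CHARS = 16  # hypothetical block size, in characters

def block_hashes(prompt, block_chars=BLOCK_CHARS):
    """Chain-hash full blocks so each hash identifies the whole prefix up to it."""
    hashes, parent = [], b""
    for start in range(0, len(prompt) - block_chars + 1, block_chars):
        block = prompt[start:start + block_chars].encode("utf-8")
        parent = hashlib.sha256(parent + block).digest()
        hashes.append(parent)
    return hashes

class ApproxPrefixIndex:
    """LRU map from prefix hash to the pods it was routed to (illustrative)."""

    def __init__(self, capacity=1024):
        self.capacity = capacity
        self._index = OrderedDict()  # hash -> set of pod names

    def record(self, prompt, pod):
        # Called after each routing decision: assume the pod now caches the prompt.
        for h in block_hashes(prompt):
            self._index.setdefault(h, set()).add(pod)
            self._index.move_to_end(h)
        while len(self._index) > self.capacity:
            self._index.popitem(last=False)  # evict the least recently used hash

    def matched_blocks(self, prompt, pod):
        """Count leading blocks of the prompt believed to be cached on the pod."""
        count = 0
        for h in block_hashes(prompt):
            if pod not in self._index.get(h, ()):
                break
            count += 1
        return count

system = "system prompt | " * 3  # exactly three 16-character blocks
index = ApproxPrefixIndex()
index.record(system + "What is KV cache?", "pod-a")
warm = index.matched_blocks(system + "Explain paging please.", "pod-a")
cold = index.matched_blocks(system + "Explain paging please.", "pod-b")
```

The index can be wrong in both directions, since the engine may have evicted a block the router still believes is cached. That gap is what the precise, event-fed path closes.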
P/D disaggregation (docs/architecture/advanced/disaggregation/README.md) — the EPP’s disagg-profile-handler runs a decode profile, asks a decider plugin whether the uncached suffix on the chosen decode pod is large enough to justify disaggregation, and only then runs the prefill profile. Decode endpoint becomes the proxy’s primary destination; the prefill endpoint rides along in the x-prefiller-host-port header. The sequence, from that doc:
sequenceDiagram
Client->>Proxy: Request
Proxy-->>EPP: Run EPP protocol
EPP-->>Proxy: Selects P Worker and D Worker
Proxy->>DSidecar: Request
DSidecar->>PWorker: Request with max_tokens=1, do_remote_decode=True
PWorker->>DSidecar: Response with KVTransferParams
DSidecar->>DWorker: Request with KVTransferParams and do_remote_prefill=True
DWorker-->>PWorker: Pull KV Cache (NIXL RDMA)
KV transfer uses NIXL (NVIDIA’s transfer library, from the ai-dynamo org — shared infrastructure with Dynamo) over UCX/UCCL/libfabric on IB/RoCE/EFA; TCP fallback exists “for testing and development” only. The vLLM protocol (nixlv2) is two-phase sequential; SGLang’s is concurrent with out-of-band bootstrap coordination. The corresponding EPP config is in guides/pd-disaggregation/router/pd-disaggregation.values.yaml — two schedulingProfiles named prefill and decode, each with its own prefill-filter/decode-filter plus prefix/queue/kv-utilization scorers, composed by disagg-profile-handler. The reference deployment is openai/gpt-oss-120b with 8 TP=1 prefill instances and 2 TP=4 decode instances (“heterogeneous parallelism”; tune your xPyD ratio to your ISL/OSL).
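The decide-then-maybe-disaggregate logic described above reduces to a small function. The Python sketch below is illustrative only: plan_request, its arguments, and the 256-token threshold are hypothetical, and only the x-prefiller-host-port header name comes from the text.

```python
PREFILL_HEADER = "x-prefiller-host-port"  # header name from the text above

def plan_request(prompt_tokens, cached_tokens_on_decode, decode_pod,
                 pick_prefill_pod, min_uncached_tokens=256):
    """Return (primary destination, extra headers) for one request.

    min_uncached_tokens is a hypothetical threshold; the real decider is a plugin.
    """
    uncached = max(prompt_tokens - cached_tokens_on_decode, 0)
    headers = {}
    if uncached >= min_uncached_tokens:
        # Only now is a prefill pod chosen; short suffixes stay on the decode pod.
        headers[PREFILL_HEADER] = pick_prefill_pod()
    return decode_pod, headers

def choose_prefill():
    # Stand-in for running the prefill scheduling profile.
    return "10.0.0.3:8000"

long_dest, long_headers = plan_request(4096, 512, "10.0.0.7:8000", choose_prefill)
short_dest, short_headers = plan_request(4096, 4000, "10.0.0.7:8000", choose_prefill)
```

The decode pod is the primary destination in both cases. The header is what tells the sidecar in that pod whether to run a remote prefill first.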
Wide expert-parallelism (guides/wide-ep-lws/README.md) — DeepSeek-R1 at DP=16 prefill + DP=16 decode over 32 H200/B200 GPUs, deployed with the LeaderWorkerSet (LWS) controller, requiring full-mesh all-to-all RDMA for DeepEP (“rail-only connectivity will fail”). This is the path the custom ghcr.io/llm-d/llm-d-cuda images (built from docker/Dockerfile.cuda with NVSHMEM/DeepEP patches from patches/) exist for.
4.6 Role of the Gateway API Inference Extension
Cross-reference for your gateway-api-inference-extension study: llm-d consumes GAIE at three layers, all visible in this repo.
- CRDs: every guide starts with kubectl apply -f .../gateway-api-inference-extension/releases/download/v1.5.0/v1-manifests.yaml (InferencePool is inference.networking.k8s.io/v1; InferenceObjective and InferenceModelRewrite are llm-d.ai/v1alpha2; see docs/api-reference/README.md).
- Charts: the standalone/gateway router Helm charts are published from GAIE’s config/charts/ today (docs/getting-started/artifacts.md).
- Framework: the EPP in llm-d-router is built on GAIE’s endpoint-picker framework and Endpoint Picking Protocol; llm-d adds the plugin set (precise prefix cache, disagg handler, latency predictor, flow control policies) and the production recipes. PROJECT.md’s “no forks” principle is the governing relationship.
In short: GAIE is the routing framework and API; llm-d is the opinionated, benchmarked distribution of it, plus the engine-side pieces (sidecar, KV indexer, images) GAIE doesn’t own.
4.7 Deployment shape: what a minimal install looks like
From docs/getting-started/quickstart.md, an install is exactly three commands, the first of which applies the GAIE CRDs:
kubectl apply -f https://github.com/kubernetes-sigs/gateway-api-inference-extension/releases/download/${GAIE_VERSION}/v1-manifests.yaml
helm install ${GUIDE_NAME} \
  oci://registry.k8s.io/gateway-api-inference-extension/charts/standalone \
  -f guides/recipes/router/base.values.yaml \
  -f guides/optimized-baseline/router/optimized-baseline.values.yaml \
  -n ${NAMESPACE} --version ${GAIE_VERSION}
kubectl apply -n ${NAMESPACE} -k guides/optimized-baseline/modelserver/gpu/vllm/
Pattern: Helm for the router (chart from GAIE/llm-d registries, configured by layered values files from this repo) + Kustomize for the model servers (base Deployment in guides/recipes/modelserver/base/single-host/default/decode-deployment.yaml — a plain Deployment labeled llm-d.ai/role: decode with /health + /v1/models probes — patched per guide/accelerator, e.g. guides/optimized-baseline/modelserver/gpu/vllm/base/patch-vllm.yaml setting vllm serve Qwen/Qwen3-32B --tensor-parallel-size=2 --gpu-memory-utilization=0.95, 8 replicas). The EPP discovers pods via the InferencePool’s matchLabels (llm-d.ai/guide: "optimized-baseline"). Worth reading in guides/recipes/router/base.values.yaml: the Envoy sidecar args block, with an inline comment explaining why --log-level warn and --concurrency 8 override the chart defaults (trace logging and hardware_concurrency() worker threads oversubscribing the cgroup CPU slice — your kind of footnote).
5. Suggested reading path
- README.md, then PROJECT.md: what it is, who runs it, the no-forks upstream policy.
- docs/architecture/README.md → docs/architecture/core/router/README.md → proxy.md: Router = Proxy + EPP, standalone vs gateway modes.
- guides/no-kubernetes-deployment/router/envoy/envoy.yaml + router/epp/config.yaml: the whole data plane and a complete EndpointPickerConfig in two files, no cluster abstractions in the way.
- docs/architecture/core/router/epp/README.md → scheduling.md → flow-control.md → datalayer.md → configuration.md: the EPP internals; then docs/api-reference/endpointpickerconfig.md and epp-http-headers.md as reference.
- docs/architecture/core/inferencepool.md + docs/api-reference/inferencepool.md: the CRD bridge to GAIE.
- guides/optimized-baseline/README.md (skim manifests, read the benchmark report at the bottom): the canonical deployment and the RR-vs-EPP numbers.
- docs/architecture/advanced/kv-management/ (all four files) and guides/precise-prefix-cache-routing/: approximate vs precise cache-aware routing.
- docs/architecture/advanced/disaggregation/README.md + guides/pd-disaggregation/README.md + guides/recipes/modelserver/base/single-host/pd/vllm/patch-sidecar.yaml: P/D end to end, including the sidecar.
- guides/wide-ep-lws/README.md + docker/Dockerfile.cuda: wide-EP and what actually goes into the engine image.
- docs/resources/observability/: the operational surface. metrics.md for enabling scraping, promql.md for ready-made queries over the EPP scheduler metrics (inference_pool_per_pod_queue_size, inference_extension_prefix_indexer_hit_ratio, per-plugin latency distributions, all cataloged in the tables at the bottom of docs/architecture/core/router/epp/scheduling.md) plus vLLM metrics, and tracing.md for OTel setup.
- Skim docs/well-lit-paths/ for the rest (flow control, autoscaling, batch, agentic), and docs/proposals/ if you want to see where the project is heading (non-kubernetes-mode.md, distributed-tracing.md, autoscaler.md).
6. Connections to your other study repos
- gateway-api-inference-extension: the routing brain llm-d builds on. Everything in section 4.6; read GAIE’s Endpoint Picker Protocol proposal alongside docs/architecture/core/router/epp/README.md and the Envoy config above. llm-d is effectively GAIE’s flagship consumer plus an extended plugin catalog (the EPP code itself lives in sibling repo llm-d-router).
- vllm: the engine being orchestrated. llm-d leans on vLLM features you can study in that repo: automatic prefix caching (what the prefix scorers exploit), KV-cache events over ZMQ (what precise routing consumes), the KV-connector/NIXL interface and kv_transfer_params (what P/D rides on), /v1/completions/render tokenization, data-parallel deployment (one pod, multiple endpoints), and the OffloadingConnector (tiered prefix cache).
- dynamo: the closest competitor, from NVIDIA, and the sharpest contrast. Dynamo is a self-contained distributed runtime (its own Rust router/frontend, etcd/NATS control plane, planner) that runs on or off Kubernetes; llm-d is deliberately not a runtime. It reuses Kubernetes primitives (Gateway API, Deployments, LWS, HPA) and Envoy, adding only the EPP and sidecars. They share NIXL for KV transfer (note NIXL_REPO=github.com/ai-dynamo/nixl in docker/Dockerfile.cuda). Compare Dynamo’s KV-aware router with the EPP’s scorer pipeline, and Dynamo’s planner with llm-d’s Workload Variant Autoscaler.
- sglang: the second engine, first-class in the optimized-baseline and P/D guides (guides/optimized-baseline/modelserver/gpu/sglang/), with its own concurrent bootstrap-room KV-transfer protocol documented in docs/architecture/advanced/disaggregation/README.md; contrast with vLLM’s sequential nixlv2.
- nano-vllm: a minimal engine is the right mental model for what an llm-d “endpoint” is. The EPP only needs an OpenAI-compatible HTTP surface plus standard metrics; everything llm-d scores (queue depth, KV utilization, prefix reuse) maps to structures you can see in nano-vllm’s scheduler and block manager in a few hundred lines.
- xgrammar / flashinfer — below llm-d’s abstraction line. They live inside the engine pods (structured-output constraints and attention/sampling kernels respectively); llm-d never sees them except as their effects on per-request latency and throughput — exactly the signals the latency predictor learns. flashinfer also illustrates why decode is memory-bandwidth-bound, which is the entire premise of P/D specialization.
7. Hands-on without a 16-GPU cluster
Honest assessment: the headline paths are out of home reach — the optimized baseline wants 16 GPUs (8×TP=2 for Qwen3-32B), P/D wants RDMA between nodes, wide-EP wants 32 H200s. With one RTX 5080 (16 GB) on Windows, treat this as a docs-and-configs repo plus three realistic exercises:
- No-Kubernetes deployment, scaled down (best option). guides/no-kubernetes-deployment/README.md runs the real stack (EPP container + Envoy container + vLLM worker(s)) with Docker only, endpoints declared in a YAML file (file-discovery plugin, hot-reloaded via atomic rename). Under WSL2 with CUDA Docker, substitute the 32B model with something that fits 16 GB (e.g. a 4–8B model at --tensor-parallel-size=1) and drop the TP/shm settings; the EPP and Envoy configs are model-agnostic. You can then watch scheduling happen: send repeated shared-prefix prompts, scrape EPP metrics on :9090 (inference_extension_prefix_indexer_hit_ratio, per-pod queue gauges), and poke Envoy admin on :19000 to see the ext_proc cluster and original-dst routing.
- GPU-free routing experiments with the simulator. docs/getting-started/artifacts.md lists sibling repo llm-d-inference-sim, a “GPU-free vLLM simulator.” Point the file-discovery endpoints.yaml (or a kind/minikube InferencePool) at several simulator instances and you can exercise the full EPP plugin pipeline (scorer weights, flow-control priority bands, even the disagg profile handler) with zero accelerators. For a traffic engineer, this is the highest-signal-per-watt exercise in the project.
- kind/minikube for the control plane only. The GAIE CRDs, router Helm chart, Gateway+HTTPRoute wiring (guides/recipes/gateway/), and EPP all install on a CPU-only cluster; only the model-server Kustomize step needs GPUs (swap in CPU vLLM via guides/optimized-baseline/modelserver/cpu/vllm/, though it wants 64 cores/replica, or the simulator). Note docs/infra-providers/minikube/README.md is currently a stub (“TBD”), so expect to adapt the quickstart yourself.
Pure reading also pays here more than in most repos: the architecture docs are recent (v0.7), unusually candid about trade-offs (approximate-vs-precise tables, P/D “not a target for all workloads” guidance, NIXL TCP-is-dev-only warnings), and every claim is tied to a manifest you can open in the same checkout.
Gateway API Inference Extension
Kubernetes SIG project (kubernetes-sigs/gateway-api-inference-extension, “GIE” or “Inference Gateway”) that turns any ext-proc-capable, Gateway API-conformant proxy — Envoy Gateway, Istio, kgateway, GKE Gateway, agentgateway — into an inference-aware L7 load balancer. It ships two things: a set of CRDs (InferencePool and friends) that model a fleet of model-server pods as a routable backend, and the Endpoint Picker (EPP), a Go gRPC server implementing Envoy’s External Processing protocol that picks the destination pod per request using live vLLM/SGLang/TensorRT-LLM metrics.
Repo status note (important for contribution targeting). Per README.md:16-23, the EPP, InferenceObjective/InferenceModelRewrite APIs, and the Body-Based Router are migrating to llm-d/llm-d-inference-scheduler and llm-d/llm-d-inference-payload-processor; this repo remains the home of the InferencePool API, the lightweight EPP (lwepp), and the conformance suite. This checkout still contains the complete EPP (pkg/epp/), so it is the right place to study the architecture, but new scheduler/EPP PRs should go to the llm-d org. Everything below cites code in this repo.
1. What it is, and what classic LB cannot do
Classic load balancing assumes requests are roughly interchangeable and backends are roughly stateless: round-robin, least-request, EWMA, maglev/ring-hash for affinity. LLM serving violates every assumption:
- Per-request cost variance is enormous. A request is not a unit of work; its cost is O(prompt_tokens) for prefill plus O(output_tokens × batch_interference) for decode. Two in-flight requests can differ by 1000x in GPU-seconds. Least-request counts requests, not tokens.
- Backends are stateful. Each replica holds a KV cache (prefix cache). Routing a request to a pod that already has its prompt prefix cached skips most of prefill, but no L7 hash policy can know which pod has which prefix blocks resident, because residency changes with every admission and eviction.
- Latency is load-dependent and saturation is a cliff. Once a vLLM replica’s KV blocks are exhausted it starts queueing or preempting (recompute), and decode latency for everyone in the batch degrades. You want load-aware routing on queue depth and KV utilization, and admission control before the cliff, not after.
- LoRA multiplexing. A pod can only hold N adapters in GPU memory; routing a request for adapter X to a pod that must first swap X in costs hundreds of ms.
The project’s answer: leave the data plane in Envoy, and put a per-request scheduling decision into an external processor that watches model-server metrics at 50ms granularity. The EPP filters/scores/picks among the pods of an InferencePool and tells Envoy where to send the request via a header consumed by an ORIGINAL_DST cluster. Selection criteria (from README.md:71-81): KV-cache pressure, queue depth, prefix-cache affinity, LoRA adapter residency, request priority — with saturation-aware shedding of low-priority traffic.
2. Why you care
This is precisely the system a traffic-infra engineer sketches from first principles after learning how continuous batching works: “I need least-loaded routing where load = queue depth + KV pressure, plus consistent-hash-like affinity on prompt prefixes, plus load shedding before the saturation cliff — and I want it in my existing Envoy fleet, not a bespoke proxy.”
- The EPP is an Envoy ext-proc server (envoy.service.ext_proc.v3.ExternalProcessor) speaking full-duplex streamed gRPC. The protocol, processing modes, header mutations, dynamic metadata, ClearRouteCache, ImmediateResponse for 429/503: all the machinery is the Envoy machinery you already know.
- The scheduling layer is a clean filter → scorer → picker plugin pipeline (deliberately modeled on kube-scheduler), so the inference-specific knowledge is encapsulated in ~100-line plugins with explicit scoring math. You can read the entire decision logic in an afternoon.
- It is the vendor-neutral standard for this layer: Envoy Gateway, Istio, kgateway, GKE, and agentgateway all implement InferencePool conformance; llm-d (Google/Red Hat/IBM + vLLM collaboration) builds its inference scheduler directly on this framework. Contributions here (or in llm-d/llm-d-inference-scheduler, where the EPP now lives) land in every implementation.
- It’s GA (v1 InferencePool), but the interesting frontiers are wide open: flow control / fairness (pkg/epp/flowcontrol/ is explicitly experimental), P/D-disaggregation profiles, latency-SLO prediction-based scheduling, multi-cluster InferencePoolImport. Envoy + Go + queueing theory is exactly your toolkit.
3. Architecture map
client ──HTTP──> Gateway (Envoy)                         [data plane]
                    │  ext_proc filter (FULL_DUPLEX_STREAMED gRPC)
                    ▼
                 EPP (this repo)                          [decision plane]
                   ├─ handlers/        ext-proc protocol state machine
                   ├─ requestcontrol   Director: parse → admit → schedule → mutate
                   ├─ flowcontrol/     priority queues, fairness, saturation
                   ├─ scheduling/      filters → scorers → picker (plugins)
                   ├─ datalayer/       per-pod collectors ──/metrics──> vLLM pods
                   └─ controller/      watches InferencePool/Pods/Objectives
                    │
                    └──returns x-gateway-destination-endpoint = ip:port
                    ▼
                 Envoy ORIGINAL_DST cluster ──> chosen model-server pod
The APIs
- InferencePool (v1, api/v1/inferencepool_types.go:32): the backend type you put in an HTTPRoute.backendRefs. Spec is just three things: a label selector for member pods, targetPorts (1-8 ports; each ip:port is a distinct endpoint, used for data-parallel ranks), and endpointPickerRef pointing at the EPP Service. endpointPickerRef.failureMode (api/v1/inferencepool_types.go:168-189) is FailClose by default, exactly Envoy ext-proc’s failure_mode_allow semantics surfaced as API. Status carries per-parent Accepted/ResolvedRefs conditions written by each Gateway controller.
- InferenceObjective (v1alpha2, apix/v1alpha2/inferenceobjective_types.go:60-73): attaches a workload identity and an integer priority (default 0; priority < 0 = sheddable, see pkg/epp/util/request/sheddable.go:20-22) to requests, selected via the x-gateway-inference-objective header.
- InferenceModelRewrite (v1alpha2): maps client-facing model names to backend models/adapters with weighted splits (the EPP rewrites "model" in the JSON body and un-rewrites it in responses, pkg/epp/handlers/server.go:446-460).
- InferencePoolImport (v1alpha1): experimental multi-cluster pool export/import.
- EndpointPickerConfig (apix/config/v1alpha1/endpointpickerconfig_types.go:33): not a CRD but a config-file schema (mounted ConfigMap) declaring which plugins to instantiate, scheduling profiles with per-scorer weights, flow control, saturation detector, and parser. This file is the EPP’s “scheduler policy” surface.
The EPP process
cmd/epp/main.go → cmd/epp/runner/runner.go wires everything: a controller-runtime manager watching the pool/pods, the datalayer runtime (one collector goroutine per endpoint), the plugin registry (cmd/epp/runner/runner.go:459-509 registers every in-tree plugin factory), the scheduler, optional flow control, and the gRPC ext-proc server (default port 9002, gRPC health on 9003, Prometheus on 9090 — pkg/epp/server/options.go:93-118). The default plugin config when you supply none is queue + KV-cache + prefix scorers (pkg/epp/config/loader/defaults.go:46-103).
cmd/lwepp/ is the lightweight EPP that stays in this repo: same ext-proc protocol and InferencePool discovery, but a trivial round-robin picker (pkg/lwepp/handlers/server.go:84-89) and no metrics scraping — a minimal reference data plane for conformance and for gateways that want endpoint subsetting without smart scheduling (it recently gained port-aware filtering for data-parallel ranks).
4. Core mechanisms
4.1 The ext-proc flow
Where it sits in Envoy: the canonical config used by e2e tests is test/testdata/envoy.yaml. The ext_proc HTTP filter runs before the router with all body modes set to full-duplex streaming:
# test/testdata/envoy.yaml:97-109
- name: envoy.filters.http.ext_proc
  typed_config:
    "@type": type.googleapis.com/envoy.extensions.filters.http.ext_proc.v3.ExternalProcessor
    grpc_service:
      envoy_grpc:
        cluster_name: ext_proc
        authority: vllm-qwen3-32b-epp.$E2E_NS:9002
      timeout: 10s
    processing_mode:
      request_header_mode: SEND
      response_header_mode: SEND
      request_body_mode: FULL_DUPLEX_STREAMED
      response_body_mode: FULL_DUPLEX_STREAMED
and the route targets an ORIGINAL_DST cluster that takes its destination from the EPP-set header — this is the whole trick that lets an external process steer per-request endpoint selection without EDS churn:
# test/testdata/envoy.yaml:151-162
- name: original_destination_cluster
  type: ORIGINAL_DST
  connect_timeout: 1000s
  lb_policy: CLUSTER_PROVIDED
  circuit_breakers:
    thresholds:
    - max_connections: 40000
      max_pending_requests: 40000
      max_requests: 40000
  original_dst_lb_config:
    use_http_header: true
    http_header_name: x-gateway-destination-endpoint
Per request, Envoy streams to the EPP: RequestHeaders → N× RequestBody chunks (EoS on last) → later ResponseHeaders → N× ResponseBody chunks. The EPP’s Process() loop (pkg/epp/handlers/server.go:162) runs a per-stream state machine (StreamRequestState, pkg/epp/handlers/server.go:142-154) with a single reader goroutine and a select that can also fire an eviction channel — flow control can yank an already-queued request mid-stream by sending ImmediateResponse(429) (pkg/epp/handlers/server.go:467-480).
The scheduling decision happens at request-body-EoS (it needs the parsed body for model name and prefix hashing): pkg/epp/handlers/server.go:319-348 parses the body, calls director.HandleRequest, then builds the response Envoy applies. The decision is communicated both as a header mutation and as dynamic metadata (namespace envoy.lb, key x-gateway-destination-endpoint — pkg/epp/metadata/consts.go:26-28), because gateway integrations differ in which they consume:
// pkg/epp/handlers/request.go:87-99
return &extProcPb.ProcessingResponse{
    Response: &extProcPb.ProcessingResponse_RequestHeaders{
        RequestHeaders: &extProcPb.HeadersResponse{
            Response: &extProcPb.CommonResponse{
                ClearRouteCache: true,
                HeaderMutation: &extProcPb.HeaderMutation{
                    SetHeaders: s.generateHeaders(ctx, reqCtx),
                },
            },
        },
    },
    DynamicMetadata: dynamicMetadata,
}
The full EPP↔proxy contract is specified in docs/proposals/004-endpoint-picker-protocol/README.md: the proxy may pass a candidate subset via filter metadata envoy.lb.subset_hint / x-gateway-destination-endpoint-subset; the EPP may return a comma-separated fallback list of endpoints (pkg/epp/requestcontrol/director.go:329-333 joins multiple picks) that Envoy walks on retry; 503 if no ready endpoint, 429 to shed. Bodyless requests (e.g. GET /health) never reach the body-EoS scheduling step, so they get a random endpoint at header time (pkg/epp/handlers/request.go:41-54). Response bodies are streamed back through the EPP so it can observe usage/latency per chunk and run post-response plugins. The EPP is therefore on the data path for the whole response, which is why the protocol insists on full-duplex streaming.
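To make the fallback-list part of the contract concrete, here is a minimal Python sketch (not the repo's Go) of how an ordered list of picks could be encoded into and decoded from the destination header. The header name and the 503/429 mapping come from the text above; the helper names, the outcome strings, and the whitespace-tolerant parsing are assumptions for illustration only.

```python
DEST_HEADER = "x-gateway-destination-endpoint"


def encode_destinations(endpoints):
    """Join ordered picks into one comma-separated header value."""
    return ",".join(endpoints)


def decode_destinations(headers):
    """Return the ordered list a proxy would walk on retry (best pick first)."""
    raw = headers.get(DEST_HEADER, "")
    return [e.strip() for e in raw.split(",") if e.strip()]


def status_for(outcome):
    """Map a hypothetical outcome label to the status described in the text."""
    return {"no_ready_endpoint": 503, "shed": 429}.get(outcome)


headers = {DEST_HEADER: encode_destinations(["10.0.0.5:8000", "10.0.0.7:8000"])}
print(decode_destinations(headers))
```

The first entry is the scheduler's preferred endpoint; anything after it only matters if the proxy retries.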
The Director orchestration sequence (pkg/epp/requestcontrol/director.go:151-229): model rewrite → resolve InferenceObjective/priority → admission control → locate candidates (subset hint or full pool) → run PrepareData plugins (e.g. prefix matching) → run admission plugins → schedule → set TargetEndpoint + run PreRequest plugins.
4.2 The scheduling framework
pkg/epp/scheduling/scheduler.go:54 runs a loop of profiles chosen by a ProfileHandler (multi-profile exists for P/D disaggregation: a “prefill” profile and a decode profile can each pick an endpoint per request; see the hardcoded experimentalDefaultPrefillProfile = "prefill" in pkg/epp/framework/plugins/requestcontrol/dataproducer/approximateprefix/types.go:77-86). Each SchedulerProfile.Run is the kube-scheduler pattern:
// pkg/epp/scheduling/scheduler_profile.go:117-128
func (p *SchedulerProfile) Run(ctx context.Context, request *fwksched.InferenceRequest, cycleState *fwksched.CycleState, candidateEndpoints []fwksched.Endpoint) (*fwksched.ProfileRunResult, error) {
    endpoints := p.runFilterPlugins(ctx, request, cycleState, candidateEndpoints)
    if len(endpoints) == 0 {
        return nil, errcommmon.Error{Code: errcommmon.Internal, Msg: "no endpoints available for the given request"}
    }
    // if we got here, there is at least one endpoint to score
    weightedScorePerEndpoint := p.runScorerPlugins(ctx, request, cycleState, endpoints)
    result := p.runPickerPlugin(ctx, cycleState, weightedScorePerEndpoint)
    return result, nil
}
Scorers return map[endpoint]float64 clamped to [0,1]; the profile accumulates score × weight (pkg/epp/scheduling/scheduler_profile.go:165-168). The default profile (pkg/epp/config/loader/defaults.go:47-49 and helm config/charts/epplib/templates/_config.yaml:77-84) is:
| plugin | weight |
|---|---|
| queue-scorer | 2 |
| kv-cache-utilization-scorer | 2 |
| prefix-cache-scorer | 3 |
| max-score-picker (injected by default) | - |
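A small Python sketch of the score × weight accumulation, using the default weights from the table. The pod names and raw scores are invented, and the function names are hypothetical; the real framework is Go and is only approximated here.

```python
# Weights from the default-profile table; pods and raw scores are made up.
DEFAULT_WEIGHTS = {
    "queue-scorer": 2.0,
    "kv-cache-utilization-scorer": 2.0,
    "prefix-cache-scorer": 3.0,
}


def clamp01(x):
    """Clamp a raw scorer output into [0, 1]."""
    return max(0.0, min(1.0, x))


def weighted_scores(per_scorer, weights=DEFAULT_WEIGHTS):
    """Accumulate score * weight per endpoint across all scorers."""
    totals = {}
    for scorer_name, scores in per_scorer.items():
        weight = weights[scorer_name]
        for endpoint, score in scores.items():
            totals[endpoint] = totals.get(endpoint, 0.0) + clamp01(score) * weight
    return totals


per_scorer = {
    "queue-scorer": {"pod-a": 1.0, "pod-b": 0.0},
    "kv-cache-utilization-scorer": {"pod-a": 0.5, "pod-b": 0.75},
    "prefix-cache-scorer": {"pod-a": 0.0, "pod-b": 1.0},
}
totals = weighted_scores(per_scorer)
print(totals)
```

In this example pod-b wins 4.5 to 3.0 despite having the longest queue: with a weight of 3, a full prefix hit outweighs the queue scorer's maximum contribution of 2. That is the trade-off the default weights encode.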
The actual plugin inventory (registered in cmd/epp/runner/runner.go:459-509):
Scorers (pkg/epp/framework/plugins/scheduling/scorer/):
- queue-scorer: min-max normalization of vLLM’s waiting-queue length across candidates:
// pkg/epp/framework/plugins/scheduling/scorer/queuedepth/queue.go:93-100
// endpointScoreFunc calculates the score based on the queue size of each endpoint. Longer queue gets a lower score.
endpointScoreFunc := func(endpoint framework.Endpoint) float64 {
    if maxQueueSize == minQueueSize {
        // If all pods have the same queue size, return a neutral score
        return 1.0
    }
    return float64(maxQueueSize-endpoint.GetMetrics().WaitingQueueSize) / float64(maxQueueSize-minQueueSize)
}
- kv-cache-utilization-scorer: simply 1 - KVCacheUsagePercent (pkg/epp/framework/plugins/scheduling/scorer/kvcacheutilization/kvcache_utilization.go:76-82). Absolute, not normalized: an empty fleet scores all 1.0 and the queue scorer breaks the tie.
- prefix-cache-scorer: matchedBlocks / totalBlocks of the prompt found in the per-pod prefix index (pkg/epp/framework/plugins/scheduling/scorer/prefix/plugin.go:108-111); see below for how the index is built.
- lora-affinity-scorer: tiered constants from adapter residency metrics:
// pkg/epp/framework/plugins/scheduling/scorer/loraaffinity/lora_affinity.go:85-98
switch {
// Ideal: The adapter is already active on this model server.
case active:
    scores[endpoint] = 1.0
// Good: The model server has capacity to load at least one more adapter.
case len(endpoint.GetMetrics().ActiveModels)+len(endpoint.GetMetrics().WaitingModels) < endpoint.GetMetrics().MaxActiveModels:
    scores[endpoint] = 0.8
// Moderate: The adapter is already in the queue to be loaded on this model server.
case waiting:
    scores[endpoint] = 0.6
// Unsuitable: The model server has reached its maximum capacity and cannot load the adapter.
default:
    scores[endpoint] = 0.0
}
- running-requests-scorer: min-max on RunningRequestsSize (.../runningrequests/runningrequest.go:77-107), i.e. least-request but on the server’s own gauge.
- token-load-scorer: 1 - inFlightTokens/threshold using EPP-side in-flight token accounting (.../tokenload/token_load.go:82-105), a leading indicator that doesn’t wait for the 50ms scrape.
- latency-scorer (+ slo-headroom-tier-filter, latency-slo-admitter, predicted-latency-producer): experimental SLO-driven scheduling against a sidecar latency-prediction service (latencypredictor/, Python). Headroom-tiered filtering then scoring by predicted TTFT/TPOT headroom.
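The scoring math in the list above is small enough to restate in a few lines. The following is a Python paraphrase of three of the scorers, written from the formulas and the Go excerpts quoted here rather than from the full plugin source; the function and parameter names are hypothetical.

```python
def queue_scores(waiting):
    """Min-max normalize waiting-queue lengths; a shorter queue scores higher."""
    lo, hi = min(waiting.values()), max(waiting.values())
    if hi == lo:
        # All queues equal: every pod gets the same neutral score.
        return {pod: 1.0 for pod in waiting}
    return {pod: (hi - q) / (hi - lo) for pod, q in waiting.items()}


def kv_scores(kv_usage):
    """Absolute score: 1 - KV-cache usage fraction, with no normalization."""
    return {pod: 1.0 - usage for pod, usage in kv_usage.items()}


def lora_score(active, waiting, n_active, n_waiting, max_active):
    """Tiered constants mirroring the switch statement quoted above."""
    if active:
        return 1.0  # adapter already loaded
    if n_active + n_waiting < max_active:
        return 0.8  # room to load one more adapter
    if waiting:
        return 0.6  # adapter already queued for loading
    return 0.0      # pod is full and the adapter is not on its way


print(queue_scores({"pod-a": 0, "pod-b": 5, "pod-c": 10}))
```

Note the difference in kind: the queue scorer is relative (it always stretches the candidate set across the full 0 to 1 range), while the KV scorer is absolute, so two nearly idle pods stay nearly tied on KV but can still be separated sharply by queue depth.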
Filters: prefix-cache-affinity-filter (pkg/epp/framework/plugins/scheduling/filter/prefixcacheaffinity/plugin.go — keeps only pods whose prefix-match ratio ≥ threshold, used as a two-gate strict/loose pair around the SLO tier filter in the helm config config/charts/epplib/templates/_config.yaml:49-60), slo-headroom-tier-filter, the utilization-detector doubling as a filter (below), and a header-based test filter.
Pickers (pkg/epp/framework/plugins/scheduling/picker/): max-score-picker (shuffle for random tie-break, stable sort desc, take top-N — maxscore/picker.go:87-115), random-picker, weighted-random-picker (score-proportional sampling, used with the latency predictor to avoid thundering-herd on the best pod).
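As a rough illustration of the two non-trivial pickers, here is a Python sketch of the shuffle-then-stable-sort idea and of score-proportional sampling. It follows the one-line descriptions above, not the Go source; the uniform fallback when all scores are zero is an assumption, as are the function names.

```python
import random


def max_score_pick(scores, n=1, rng=None):
    """Shuffle for a random tie-break, stable-sort descending, take the top n."""
    rng = rng or random.Random()
    items = list(scores.items())
    rng.shuffle(items)
    # list.sort is stable, so equal scores keep their shuffled order.
    items.sort(key=lambda kv: kv[1], reverse=True)
    return [pod for pod, _ in items[:n]]


def weighted_random_pick(scores, rng=None):
    """Sample one pod with probability proportional to its score."""
    rng = rng or random.Random()
    pods = list(scores)
    weights = [scores[p] for p in pods]
    if sum(weights) <= 0:
        return rng.choice(pods)  # assumption: fall back to a uniform pick
    return rng.choices(pods, weights=weights, k=1)[0]


print(max_score_pick({"pod-a": 0.2, "pod-b": 0.9, "pod-c": 0.5}, n=2))
```

The shuffle matters more than it looks: without it, every EPP replica would break ties in the same order and pile onto the same pod whenever scores are equal, which is common on an idle fleet.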
Prefix-cache affinity is the most interesting subsystem (proposal: docs/proposals/0602-prefix-cache-aware-routing-proposal/). The EPP cannot see vLLM’s actual block tables, so it maintains an approximation: an LRU index mapping xxhash block hashes → set of pods that recently served them. The chunking mirrors vLLM’s block hashing, chained like a Merkle list and salted by model name:
// pkg/epp/framework/plugins/requestcontrol/dataproducer/approximateprefix/hashing.go:70-86
h := xxhash.New()
// Different models should have different hashes even with the same body.
_, _ = h.Write([]byte(request.TargetModel))
if cacheSalt := request.Body.CacheSalt(); cacheSalt != "" {
    _, _ = h.Write([]byte(cacheSalt))
}
prevBlockHash := blockHash(h.Sum64())
i := 0
for ; i+cacheBlockSizeChars <= len(userInput); i += cacheBlockSizeChars {
    h.Reset()
    _, _ = h.Write(userInput[i : i+cacheBlockSizeChars])
    _, _ = h.Write(toBytes(prevBlockHash))
    res = append(res, blockHash(h.Sum64()))
    prevBlockHash = res[len(res)-1]
}
Defaults (approximateprefix/types.go:88-113): block size 16 tokens (vLLM default; auto-tuned from the server’s cache_config_info metric when AutoTune is on), ~4 chars/token heuristic (no tokenizer in the hot path), max 256 blocks matched, LRU capacity 31,250 entries/pod (sized from an H100 KV-budget calculation in the comment). The flow is split across two hooks of the same plugin: PrepareRequestData hashes the prompt and annotates every candidate with PrefixCacheMatchInfo before scheduling; PreRequest records the chosen pod (and the prefill pod, in P/D mode) against those hashes after scheduling, asynchronously (approximateprefix/plugin.go:140-205). The scorer itself just reads the annotation — production of state and consumption are decoupled through the datalayer attribute system (Consumes()/Produces() declarations let the config loader validate the plugin graph).
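The chained hashing and the resulting match ratio can be tried out in a few lines of Python. This is a sketch under stated assumptions: blake2b stands in for xxhash (so the hash values will not match the EPP's), 64 characters per block is just 16 tokens × ~4 chars/token from the defaults above, and the match is assumed to stop at the first missing block (longest-prefix semantics). The cache-salt step and the LRU eviction are left out.

```python
import hashlib

BLOCK_CHARS = 64  # assumption: 16 tokens/block * ~4 chars/token


def _hash64(data):
    # Stand-in 64-bit hash; the quoted Go code uses xxhash instead.
    return hashlib.blake2b(data, digest_size=8).digest()


def block_hashes(model, prompt, block_chars=BLOCK_CHARS):
    """Chain-hash each full block of the prompt, seeded by the model name."""
    prev = _hash64(model.encode())
    hashes = []
    data = prompt.encode()
    for i in range(0, len(data) - block_chars + 1, block_chars):
        prev = _hash64(data[i:i + block_chars] + prev)
        hashes.append(prev)
    return hashes


def match_ratio(hashes, pod_index):
    """Fraction of leading blocks present in a pod's index."""
    if not hashes:
        return 0.0
    matched = 0
    for h in hashes:
        if h not in pod_index:
            break  # chained hashes: a miss invalidates everything after it
        matched += 1
    return matched / len(hashes)


shared = "s" * 128  # two full blocks of shared system prompt
first = block_hashes("model-x", shared + "x" * 64)
second = block_hashes("model-x", shared + "y" * 64)
print(match_ratio(second, set(first)))
```

Because each block hash folds in its predecessor, two prompts agree on block k only if they agree on every block before it, which is exactly the property a prefix cache needs. A trailing partial block is never hashed, so prompts shorter than one block produce no hashes and score 0.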
4.3 Per-pod state: the datalayer
Service discovery: a controller watches pods matching the pool selector (or a static --endpoint-selector in standalone mode) and registers one collector per endpoint. Each collector runs a ticker loop polling its data sources and feeding extractors:
// pkg/epp/datalayer/collector.go:139-158
case <-ticker.Channel():
    for _, src := range sources {
        tn := src.TypedName()
        key := tn.String()
        ctx, cancel := context.WithTimeout(c.ctx, defaultCollectionTimeout)
        data, err := src.Poll(ctx, endpoint)
        cancel()
        logErrorTransition(logger, c.lastPollErrors, key, "poll", "source", err)
        if err != nil {
            continue
        }
        if srcExtractors, ok := exts[tn.Name]; ok && data != nil {
            for _, ext := range srcExtractors {
                extErr := ext.Extract(ctx, data, endpoint)
Defaults (pkg/epp/server/options.go:102-109): scrape every 50ms per pod with a 1s timeout, metrics-staleness-threshold 2s, and the vLLM metric names as defaults: vllm:num_requests_waiting, vllm:num_requests_running, vllm:kv_cache_usage_perc, vllm:lora_requests_info, vllm:cache_config_info. The metrics-data-source plugin GETs /metrics and parses Prometheus text; the core-metrics-extractor maps families into the shared Metrics snapshot (pkg/epp/framework/interface/datalayer/metrics.go:26-42: ActiveModels/WaitingModels/MaxActiveModels, RunningRequestsSize, WaitingQueueSize, KVCacheUsagePercent, CacheBlockSize, CacheNumBlocks, UpdateTime). Engine dialects are table-driven — the same extractor handles vLLM, SGLang, trtllm-serve, and Triton via per-engine metric specs selected by a pod label (inference.networking.k8s.io/engine-type):
// pkg/epp/framework/plugins/datalayer/extractor/metrics/factories.go:89-98
{
    Name:                    "sglang",
    QueuedRequestsSpec:      "sglang:num_queue_reqs",
    RunningRequestsSpec:     "sglang:num_running_reqs",
    KVUsageSpec:             "sglang:token_usage",
    LoRASpec:                "",
    CacheInfoSpec:           "sglang:cache_config_info",
    CacheBlockSizeLabelName: "page_size",
    CacheNumBlocksLabelName: "num_pages",
},
The scrape contract a model server must satisfy is the Model Server Protocol (docs/proposals/003-model-server-protocol/README.md). Staleness handling is consumer-side: every snapshot carries UpdateTime, and consumers like the saturation detector treat stale pods as saturated (its own staleness bound defaults to 200ms — pkg/epp/framework/plugins/flowcontrol/saturationdetector/utilization/config.go:36), while readiness/debug listers partition pods into fresh/stale against the 2s threshold (pkg/epp/datalayer/logger/logger.go:78-93).
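To see what one scrape turns into, here is a deliberately simplified Python sketch that pulls the three pressure gauges out of Prometheus text and applies a consumer-side staleness check. The metric names and the 2s/200ms bounds come from the text above; the parser is an assumption-laden toy (it expects `name{labels} value` lines without trailing timestamps) and the snapshot field names are hypothetical.

```python
WAITING = "vllm:num_requests_waiting"
RUNNING = "vllm:num_requests_running"
KV_USAGE = "vllm:kv_cache_usage_perc"


def parse_gauges(text, names):
    """Extract the last sample of each named metric from Prometheus text."""
    found = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue  # skip blanks and HELP/TYPE comments
        series, _, value = line.rpartition(" ")
        name = series.split("{", 1)[0]
        if name in names:
            found[name] = float(value)
    return found


def snapshot(text, now):
    """Build a per-pod snapshot stamped with the scrape time."""
    gauges = parse_gauges(text, {WAITING, RUNNING, KV_USAGE})
    return {
        "waiting_queue_size": int(gauges.get(WAITING, 0)),
        "running_requests_size": int(gauges.get(RUNNING, 0)),
        "kv_cache_usage": gauges.get(KV_USAGE, 0.0),
        "update_time": now,
    }


def is_stale(snap, now, threshold_s=2.0):
    """Consumer-side check: each consumer applies its own staleness bound."""
    return now - snap["update_time"] > threshold_s


sample = "\n".join([
    "# TYPE vllm:num_requests_waiting gauge",
    'vllm:num_requests_waiting{model_name="m"} 3.0',
    'vllm:num_requests_running{model_name="m"} 7.0',
    'vllm:kv_cache_usage_perc{model_name="m"} 0.42',
])
snap = snapshot(sample, now=100.0)
print(snap, is_stale(snap, now=100.3, threshold_s=0.2))
```

The same snapshot can be fresh for one consumer and stale for another: at 300ms old it passes the 2s lister threshold but fails the saturation detector's 200ms bound.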
4.4 Flow control: admission, queueing, saturation
Two admission paths exist behind the AdmissionController interface (pkg/epp/requestcontrol/admission.go:39-56):
Legacy (default) — stateless shed-or-pass: sheddable (priority < 0) requests are rejected with 429 when pool saturation ≥ 1.0; non-sheddable always pass to the scheduler. Saturation comes from the utilization-detector, a roofline model on the two scraped pressure signals:
// pkg/epp/framework/plugins/flowcontrol/saturationdetector/utilization/detector.go:124-136
if metrics == nil || time.Since(metrics.UpdateTime) > d.config.MetricsStalenessThreshold {
    totalScore += 1.0
    continue
}
qRatio := float64(metrics.WaitingQueueSize) / float64(d.config.QueueDepthThreshold)
kvRatio := metrics.KVCacheUsagePercent / d.config.KVCacheUtilThreshold
// Roofline Analysis: The pod is saturated if either resource is exhausted.
totalScore += max(qRatio, kvRatio)
The same detector doubles as a scheduling filter that drops pods beyond threshold × (1 + headroom) — with a fail-open clause returning all endpoints if everything is saturated (detector.go:142-168).
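The legacy path fits in one screen of Python. The sketch below restates the per-pod roofline score from the Go excerpt, then makes two labeled assumptions: that pool saturation is the mean of the per-pod scores (the excerpt only shows the accumulation into totalScore), and that the filter compares the same ratio against 1 + headroom. The threshold values (queue depth 5, KV 0.8) are illustrative, not the repo's defaults.

```python
def pod_saturation(m, now, queue_threshold, kv_threshold, staleness_s):
    """Roofline score: the worse of the queue and KV ratios; stale counts as 1.0."""
    if m is None or now - m["update_time"] > staleness_s:
        return 1.0
    q_ratio = m["waiting_queue_size"] / queue_threshold
    kv_ratio = m["kv_cache_usage"] / kv_threshold
    return max(q_ratio, kv_ratio)


def pool_saturation(pods, now, **cfg):
    """Assumption: pool saturation is the mean of the per-pod scores."""
    if not pods:
        return 1.0
    return sum(pod_saturation(m, now, **cfg) for m in pods.values()) / len(pods)


def admit(priority, saturation):
    """Legacy path: shed only sheddable (priority < 0) traffic at saturation >= 1.0."""
    return not (priority < 0 and saturation >= 1.0)


def filter_unsaturated(pods, now, headroom, **cfg):
    """Drop pods beyond threshold * (1 + headroom); fail open if none survive."""
    keep = {name: m for name, m in pods.items()
            if pod_saturation(m, now, **cfg) <= 1.0 + headroom}
    return keep or dict(pods)


cfg = dict(queue_threshold=5, kv_threshold=0.8, staleness_s=0.2)  # illustrative
pods = {
    "pod-a": {"waiting_queue_size": 0, "kv_cache_usage": 0.4, "update_time": 10.0},
    "pod-b": {"waiting_queue_size": 10, "kv_cache_usage": 0.4, "update_time": 10.0},
}
sat = pool_saturation(pods, now=10.1, **cfg)
print(sat, admit(-1, sat), admit(0, sat))
```

Here pod-a scores 0.5 and pod-b scores 2.0, so the pool sits at 1.25: sheddable requests get a 429, while priority 0 traffic still goes to the scheduler, where the filter steers it to pod-a.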
Flow Control layer (experimental, pkg/epp/flowcontrol/) — a full queueing system: FlowControlAdmissionController.Admit wraps the request and blocks in FlowController.EnqueueAndWait (pkg/epp/flowcontrol/controller/controller.go:206). The controller is a supervisor over sharded ShardProcessor workers; requests land in per-flow queues keyed by x-gateway-inference-fairness-id within priority bands. Pluggable policies, all registered in the same plugin registry: fairness (round-robin, global-strict), ordering (fcfs, edf earliest-deadline-first, slo-deadline), usage limits (static token budgets), and queue data structures (ListQueue, MaxMinHeap — pkg/epp/flowcontrol/framework/plugins/queue/). Outcomes map to ext-proc responses: dispatch → continue to scheduler; reject/TTL-expiry → 429/503; and queued requests can be evicted post-enqueue, which is what that eviction channel in the handler’s select loop is for (pkg/epp/handlers/server.go:260-275). The design doc pkg/epp/flowcontrol/README.md is one of the best-written pieces in the repo — head-of-line blocking, displacement, and band-relative fairness are all spelled out.
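For intuition about how bands, flows, and policies compose, here is a toy Python model: strict priority across bands, round-robin across flows within a band, FCFS within a flow. It is a single-threaded sketch of one possible policy combination, not the sharded controller; the class and method names are hypothetical, and it assumes a higher priority integer is served first (consistent with negative priorities being sheddable).

```python
from collections import OrderedDict, deque


class ToyFlowQueues:
    """Strict priority across bands, round-robin across flows, FCFS within a flow."""

    def __init__(self):
        # priority -> OrderedDict[fairness_id, deque of requests]
        self.bands = {}

    def enqueue(self, priority, fairness_id, request):
        flows = self.bands.setdefault(priority, OrderedDict())
        flows.setdefault(fairness_id, deque()).append(request)

    def dispatch(self):
        """Pop the next request, or return None when every queue is empty."""
        for priority in sorted(self.bands, reverse=True):
            flows = self.bands[priority]
            if not flows:
                continue
            fairness_id, queue = next(iter(flows.items()))
            request = queue.popleft()
            del flows[fairness_id]
            if queue:
                flows[fairness_id] = queue  # rotate this flow to the back
            return request
        return None


q = ToyFlowQueues()
q.enqueue(0, "tenant-a", "a1")
q.enqueue(0, "tenant-a", "a2")
q.enqueue(0, "tenant-b", "b1")
q.enqueue(1, "tenant-c", "c1")
order = [q.dispatch() for _ in range(5)]
print(order)
```

Even in this toy the key property is visible: tenant-a enqueued two requests before tenant-b enqueued one, yet b1 is served before a2, so one chatty flow cannot starve its neighbors within a band.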
4.5 Conformance: what a Gateway implementation must do
conformance/ vendors the upstream Gateway API conformance machinery and defines a Gateway profile (conformance/conformance.go:60-66). To claim InferencePool support an implementation must pass tests covering, among others (conformance/tests/):
- inferencepool_accepted, inferencepool_resolvedrefs_condition: status conditions written per parent Gateway.
- gateway_following_epp_routing (+ _dp for multi-port data parallelism): deploy a real EPP configured with a header-based test filter, assert the gateway actually routes to the exact pod the EPP names. This is the heart of conformance: the proxy must honor x-gateway-destination-endpoint.
- epp_unavailable_fail_open: kill the EPP; failureMode: FailOpen pools must still serve.
- gateway_destination_endpoint_served: the proxy must report which endpoint actually served via response-path metadata.
- inferencepool_invalid_epp_service, httproute_invalid_inferencepool_ref, port validation, multiple-gateways/pools weighting tests.
Backends in conformance are echo servers, not model servers (conformance/resources/base.yaml:60-83) — conformance tests routing mechanics, not scheduling quality. Reports are published per-implementation under conformance/reports/. If Discord ever fronts inference with its own Envoy control plane, this suite is the compliance bar.
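The two behaviors at the core of the suite, honoring the EPP's destination header and failing open when the EPP is down, reduce to a small decision function. The sketch below is a Python model of that contract under stated assumptions, not gateway code: `choose_backend` and `EppUnavailable` are hypothetical names, the example addresses are made up, picking the first pool member on fail-open is an arbitrary choice, and treating every mode other than FailOpen as fail-closed is an assumption.

```python
DEST_HEADER = "x-gateway-destination-endpoint"


class EppUnavailable(Exception):
    """Raised when the EPP cannot be reached and the pool is not fail-open."""


def choose_backend(epp_headers, pool_endpoints, failure_mode):
    """Model of the routing contract a conformant proxy has to satisfy.

    epp_headers is the header mutation returned by the EPP, or None when
    the EPP could not be reached.
    """
    if epp_headers is None:
        if failure_mode == "FailOpen":
            # Assumption: any pool member is acceptable; the choice of the
            # first one is arbitrary and only keeps the example deterministic.
            return pool_endpoints[0]
        raise EppUnavailable("EPP unreachable and pool is not FailOpen")
    # The proxy must route to exactly the endpoint the EPP named.
    return epp_headers[DEST_HEADER]


pool = ["10.0.0.1:8000", "10.0.0.2:8000"]  # hypothetical echo-server pods
picked = choose_backend({DEST_HEADER: "10.0.0.2:8000"}, pool, "FailOpen")
fallback = choose_backend(None, pool, "FailOpen")
```

Because the backends are echo servers, the tests can check `picked` against the pod that actually answered, which is all the routing-mechanics claim requires.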
5. Suggested reading path
- README.md, then docs/proposals/002-api-proposal/README.md (why pool+objective) and docs/proposals/004-endpoint-picker-protocol/README.md (the EPP↔Envoy contract; short, read fully).
- api/v1/inferencepool_types.go, apix/v1alpha2/inferenceobjective_types.go: the API surface.
- test/testdata/envoy.yaml: the raw Envoy config; map every field to the protocol doc.
- pkg/epp/handlers/server.go (Process loop) and pkg/epp/handlers/request.go: the ext-proc state machine.
- pkg/epp/requestcontrol/director.go:151-229: the orchestration spine.
- pkg/epp/scheduling/scheduler.go + scheduler_profile.go: the framework; then docs/proposals/0845-scheduler-architecture-proposal/.
- Scorers in order: queuedepth/queue.go, kvcacheutilization/kvcache_utilization.go, loraaffinity/lora_affinity.go, then the prefix pair: requestcontrol/dataproducer/approximateprefix/{hashing,indexer,plugin}.go + scorer/prefix/plugin.go.
- pkg/epp/datalayer/collector.go + framework/plugins/datalayer/extractor/metrics/: the metrics pipeline; pkg/epp/server/options.go for every default.
- pkg/epp/flowcontrol/README.md, then requestcontrol/admission.go and saturationdetector/utilization/detector.go; go deeper into flowcontrol/controller/ only if fairness interests you.
- pkg/epp/config/loader/defaults.go + apix/config/v1alpha1/endpointpickerconfig_types.go: how config becomes a plugin graph.
- conformance/tests/gateway_following_epp_routing.go and test/integration/epp/harness.go: how it’s all verified.
6. Connections to your other study repos
- llm-d: the most direct. llm-d’s inference scheduler is this EPP framework (the code is migrating to llm-d/llm-d-inference-scheduler), with llm-d adding disaggregated prefill/decode plugins, KV-cache-event-based (precise, not approximate) prefix indexing, and the vLLM-side integration. The multi-profile scheduler and the hardcoded "prefill" profile name here are the seams llm-d plugs into. Study GIE first; llm-d then reads as “GIE plus opinionated vLLM deployment.”
- vllm: the backend whose telemetry drives everything. The metrics vllm:num_requests_waiting, vllm:kv_cache_usage_perc, and vllm:lora_requests_info are the EPP’s eyes (pkg/epp/server/options.go:105-109). The approximate prefix index mirrors vLLM’s block-hash chaining (prefix_caching design) at 16-token granularity. KV pressure → preemption → latency cliff is the vLLM behavior the saturation detector encodes.
- sgl-router & dynamo: the architectural counterpoint. Both put inference-aware routing in the data plane process (SGLang’s Rust router with cache-aware tree matching; Dynamo’s distributed runtime with its own KV-aware planner/router tier), getting tokenizer-exact prefix matching and event-driven cache state at the cost of owning the proxy: TLS, HTTP/2, retries, observability, deployment. GIE instead splits decision plane (EPP) from data plane (any conformant Envoy), paying one ext-proc RTT (~ms, plus full-duplex body streaming) and accepting approximate cache state, in exchange for vendor-neutrality and reuse of the Envoy ecosystem. Note GIE hedges on precision: char-based heuristic hashing here, with the event-driven precise indexer living in llm-d. Knowing both sides of this trade is exactly the judgment an infra interviewer probes.
- nano-vllm: the minimal lab for why these signals exist. See its block manager and scheduler to internalize what kv_cache_usage_perc and num_requests_waiting physically mean before tuning scorers that consume them.
- xgrammar: an orthogonal layer (constrained decoding inside the engine); the only contact point is that structured-output requests skew decode cost, which token-load/latency scorers absorb statistically.
- flashinfer: the layer beneath vLLM that makes decode latency batch-size-dependent (paged KV attention kernels). That kernel-level fact is the root cause of the “load-dependent latency” that makes least-request insufficient; the EPP is the system-level compensation.
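The block-hash chaining mentioned in the vllm and sgl-router entries is easy to model: each block's hash folds in the hash of the block before it, so a hash match at position k implies the whole prefix up to k matches. The following is a minimal Python sketch of that idea on characters, not the repo's hashing code. `block_hashes` and `matched_blocks` are hypothetical names, SHA-256 is a stand-in hash, and the 64-character block size is an assumption (roughly 16 tokens at an assumed 4 characters per token).

```python
import hashlib

BLOCK_CHARS = 64  # assumption: ~16 tokens at ~4 characters per token


def block_hashes(prompt, block_chars=BLOCK_CHARS):
    """Hash each full block chained with its predecessor's hash."""
    hashes = []
    parent = b""
    full = len(prompt) - len(prompt) % block_chars  # ignore the partial tail
    for start in range(0, full, block_chars):
        block = prompt[start:start + block_chars].encode("utf-8")
        parent = hashlib.sha256(parent + block).digest()
        hashes.append(parent)
    return hashes


def matched_blocks(prompt, cached_hashes):
    """Count leading blocks of prompt believed to be cached on a pod."""
    count = 0
    for h in block_hashes(prompt):
        if h not in cached_hashes:
            break
        count += 1
    return count


system_prompt = "S" * 128                               # two full blocks
pod_cache = set(block_hashes(system_prompt + "A" * 64))  # pod served this

shared = matched_blocks(system_prompt + "B" * 64, pod_cache)    # 2
shifted = matched_blocks("X" * 64 + system_prompt, pod_cache)   # 0
```

The second call returns 0 even though the same 128 characters appear later in the prompt: chaining makes the index a prefix index, not a substring index, which matches how KV cache reuse works in the engine.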
7. Hands-on without a GPU fleet
Everything here runs on a laptop:
- Unit tests (make test-unit): pure Go, no cluster. Highest-value reads/runs: pkg/epp/scheduling/scheduler_test.go (full filter/score/pick cycles against fake metrics), approximateprefix/plugin_test.go and indexer_test.go (prefix matching end-to-end), saturationdetector/utilization/detector_test.go, pkg/epp/handlers/server_abort_test.go.
- Hermetic integration tests (make test-integration, test/integration/epp/hermetic_test.go): boots the real EPP gRPC server via controller-runtime envtest (local kube-apiserver, no kubelets), injects pod metrics through FakePodMetricsClient / a mock datalayer source (test/integration/epp/harness.go:240-263), then drives the actual ext-proc stream with handcrafted ProcessingRequests and asserts on the returned header mutations; i.e., you can watch a scheduling decision change as you flip a fake pod’s WaitingQueueSize. This is the best place to experiment with new scorer behavior.
- Fake backends are first-class: the e2e suite and all getting-started guides default to the vLLM simulator ghcr.io/llm-d/llm-d-inference-sim (config/manifests/vllm/sim-deployment.yaml), a CPU-only container that speaks the OpenAI API and emits protocol-conformant vllm:* metrics, including LoRA. A kind cluster + any conformant gateway (kgateway, Istio, Envoy Gateway) + sim deployment gives you the full path: site-src/guides/index.md is the walkthrough; make test-e2e automates it (test/e2e/epp/README.md).
- Standalone mode, no Gateway API at all (site-src/guides/standalone.md): EPP + Envoy as a sidecar pair, pods discovered by --endpoint-selector app=... label selector. This is the minimal lab to watch ext-proc traffic with the simulator, and incidentally the deployment shape closest to “drop an inference scheduler into an existing Envoy fleet.”
- Conformance against kind: go test ./conformance --run TestConformance with --gateway-class pointed at your gateway; echo-server backends only, no accelerators (site-src/guides/conformance-tests.md).
- Flow control benchmarks: pkg/epp/flowcontrol/benchmark/ + make test-benchmark for queue/fairness micro-benchmarks.
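Before running the hermetic tests, it helps to have a mental model of what "flip a fake pod's WaitingQueueSize and watch the decision change" looks like. The sketch below is a toy Python model of a queue-depth scorer and picker, not the repo's Go plugin: `queue_scores` and `pick` are hypothetical names, and min-max normalization with an alphabetical tie-break is an assumption chosen to keep the example deterministic.

```python
def queue_scores(waiting):
    """Map each pod's waiting-queue size to a score in [0, 1]; shorter is better."""
    lo, hi = min(waiting.values()), max(waiting.values())
    if hi == lo:
        # All pods look the same on this signal, so none is penalized.
        return {pod: 1.0 for pod in waiting}
    return {pod: (hi - depth) / (hi - lo) for pod, depth in waiting.items()}


def pick(waiting):
    """Pick the highest-scoring pod; ties go to the alphabetically first name."""
    scores = queue_scores(waiting)
    return max(sorted(scores), key=scores.get)


# Fake metrics, standing in for what a fake metrics client would inject.
pods = {"pod-a": 0, "pod-b": 5}
before = pick(pods)   # 'pod-a': its queue is empty

pods["pod-a"] = 20    # flip the fake pod's waiting-queue size
after = pick(pods)    # 'pod-b': pod-a is now the most backed-up
```

In the real harness the same experiment runs through the ext-proc stream, and the observable result is the destination header in the returned mutation rather than a function return value.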
Good first contribution surfaces (judged from TODOs and experimental markers in-tree, modulo the llm-d migration): the RequestContext protocol/lifecycle decoupling (pkg/epp/handlers/server.go:93-95 TODO), configurable prepareDataTimeout (pkg/epp/requestcontrol/director.go:52-55), the canonical P/D profile mechanism replacing the hardcoded "prefill" name (issue #2080), flow-control policy plugins, and conformance tests — which this repo explicitly retains and which are pure Gateway/Envoy engineering.