nano-vllm
A from-scratch reimplementation of vLLM’s core inference loop in ~1,450 lines of Python (version 0.2.0, commit bb823b3, “chunked-prefill-refactor”). One model family (Qwen3 dense), one platform (CUDA), offline batch only — but the four mechanisms that make modern LLM serving fast are all here, each in a single readable file: continuous batching, paged KV cache, hash-based prefix caching, and CUDA-graph decode. It also includes tensor parallelism over a homemade shared-memory RPC, plus chunked prefill.
Why read this before real vLLM: vLLM V1 implements the same architecture — same queues, same block hashing, same prefill/decode batch preparation — but spread across hundreds of files with feature flags, backends, and abstraction layers. nano-vllm is the architecture with the adjectives removed. After this, opening vllm/v1/core/sched/scheduler.py feels like reading a familiar program with more options, not an alien codebase. The author benchmarks it slightly faster than vLLM on small models (README.md: 1434 vs 1362 tok/s, Qwen3-0.6B on an 8GB laptop GPU), so it is not a toy that cheats.
Why you care
- The scheduler is an admission controller. Scheduler.schedule() is 49 lines that decide, every iteration, which requests get GPU time against a token budget and a block pool. Structurally it is the same problem as adaptive concurrency / load shedding in Envoy, except the contended resource is KV-cache blocks instead of connections. Preemption here is literally load shedding with retry-from-front-of-queue.
- Prefix caching is content-addressable storage with refcounts. Once you’ve read BlockManager.hash_blocks, the “KV-cache-aware routing” in llm-d, dynamo, and gateway-api-inference-extension stops being marketing: those routers are trying to predict, from outside the engine, what hash_to_block_id contains inside each replica.
- It defines the signal set. Queue depths, free blocks, prefill-vs-decode token throughput, preemption events: everything a serving-layer LB wants exported from an engine is a one-line print away in this codebase. You can watch the whole control plane think in a terminal.
Complete file map
| File | Lines | Purpose |
|---|---|---|
nanovllm/__init__.py | 2 | Exports LLM and SamplingParams |
nanovllm/llm.py | 5 | class LLM(LLMEngine): pass — the vLLM-compatible facade |
nanovllm/config.py | 25 | Config dataclass: budgets, block size, TP size; pulls HF config; asserts model path is a local dir |
nanovllm/sampling_params.py | 11 | temperature / max_tokens / ignore_eos — that’s the entire sampling API (greedy is asserted away) |
nanovllm/engine/llm_engine.py | 90 | Engine: spawns TP worker processes, owns tokenizer + scheduler, runs the step() loop, tracks tok/s |
nanovllm/engine/scheduler.py | 92 | Continuous batching: waiting/running deques, chunked prefill, preemption, finish detection |
nanovllm/engine/block_manager.py | 120 | Paged KV allocator + prefix cache via chained xxhash64 block hashing, refcounts, lazy eviction |
nanovllm/engine/sequence.py | 83 | Sequence: per-request state machine (WAITING/RUNNING/FINISHED), block table, token counters |
nanovllm/engine/model_runner.py | 257 | Per-GPU worker: KV memory sizing, batch tensor prep, CUDA graph capture/replay, sampling, shm RPC for TP |
nanovllm/layers/attention.py | 75 | Triton kernel writing K/V into paged cache + flash-attn varlen (prefill) / with-kvcache (decode) |
nanovllm/layers/sampler.py | 12 | Temperature sampling via the Gumbel/exponential trick, torch.compiled |
nanovllm/layers/linear.py | 156 | TP linear layers (Column/Row/QKV/Merged) with per-param weight_loader — vLLM’s loading idiom |
nanovllm/layers/embed_head.py | 66 | Vocab-parallel embedding; LM head that slices only each sequence’s last position before projecting |
nanovllm/layers/rotary_embedding.py | 59 | Precomputed cos/sin RoPE cache |
nanovllm/layers/layernorm.py | 50 | RMSNorm with fused residual-add variant |
nanovllm/layers/activation.py | 11 | SiLU-and-mul (SwiGLU gate) |
nanovllm/models/qwen3.py | 216 | Qwen3 decoder assembled from the layers above; packed_modules_mapping for fused-weight loading |
nanovllm/utils/context.py | 27 | Process-global Context dataclass carrying attention metadata into the forward pass |
nanovllm/utils/loader.py | 28 | Safetensors loader dispatching to per-parameter weight_loaders |
example.py | 33 | Chat-template generation demo (Qwen3-0.6B) |
bench.py | 32 | 256-sequence random-length throughput benchmark, vLLM-comparable |
Core mechanisms
The engine step loop
LLMEngine.generate() (nanovllm/engine/llm_engine.py:60-90) enqueues all prompts, then spins step() until both queues drain. Each step is one scheduler decision plus one GPU forward:
```python
def step(self):
    seqs, is_prefill = self.scheduler.schedule()
    # positive = prefill tokens this step, negative = decode batch size (display only)
    num_tokens = sum(seq.num_scheduled_tokens for seq in seqs) if is_prefill else -len(seqs)
    token_ids = self.model_runner.call("run", seqs, is_prefill)  # one forward + sampling
    self.scheduler.postprocess(seqs, token_ids, is_prefill)
    outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished]
    return outputs, num_tokens
```
(nanovllm/engine/llm_engine.py:49-55.) The key design choice: every batch is either pure prefill or pure decode — schedule() returns a flag, and the model runner builds entirely different tensors for each case. Prefill has strict priority: decode only runs when the waiting queue is empty or blocked on memory. (vLLM V1 abandoned this dichotomy — see the comparison table.) The num_tokens sign trick (positive = prefill tokens, negative = decode batch size) only feeds the tqdm throughput display.
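The sign convention is easy to decode after the fact. The helper below is hypothetical (nano-vllm only feeds each step’s value to its progress display); it is a sketch that aggregates (num_tokens, seconds) pairs from a step-like loop into separate prefill and decode rates.

```python
def split_throughput(steps: list[tuple[int, float]]) -> tuple[float, float]:
    """Aggregate (num_tokens, seconds) step records into (prefill tok/s, decode tok/s).

    num_tokens > 0: prompt tokens processed by a prefill step.
    num_tokens < 0: minus the batch size of a decode step (one new token per sequence).
    """
    prefill_tokens = prefill_time = 0.0
    decode_tokens = decode_time = 0.0
    for num_tokens, seconds in steps:
        if num_tokens > 0:
            prefill_tokens += num_tokens
            prefill_time += seconds
        else:
            decode_tokens -= num_tokens      # negate back to a token count
            decode_time += seconds
    prefill_tps = prefill_tokens / prefill_time if prefill_time else 0.0
    decode_tps = decode_tokens / decode_time if decode_time else 0.0
    return prefill_tps, decode_tps


# One 2048-token prefill step, then two decode steps over a batch of 8 sequences.
print(split_throughput([(2048, 0.5), (-8, 0.02), (-8, 0.02)]))
```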
Block manager: paged KV + prefix caching
BlockManager (nanovllm/engine/block_manager.py) owns num_kvcache_blocks physical blocks of 256 tokens each (vLLM defaults to 16). A Block is {block_id, ref_count, hash, token_ids}. Prefix caching is a chained content hash: each block’s hash commits to the entire prefix before it, so a single dict lookup confirms an arbitrarily long prefix:
```python
@classmethod
def compute_hash(cls, token_ids: list[int], prefix: int = -1):
    h = xxhash.xxh64()
    if prefix != -1:  # chain: fold in the previous block's hash first
        h.update(prefix.to_bytes(8, "little"))
    h.update(np.array(token_ids).tobytes())
    return h.intdigest()
```
(nanovllm/engine/block_manager.py:35-41.) On admission, can_allocate (nanovllm/engine/block_manager.py:58-73) walks the prompt’s full blocks, following the hash chain until the first miss:
```python
h = -1
num_cached_blocks = 0
num_new_blocks = seq.num_blocks
for i in range(seq.num_blocks - 1):           # never the last block
    token_ids = seq.block(i)
    h = self.compute_hash(token_ids, h)
    block_id = self.hash_to_block_id.get(h, -1)
    if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
        break                                  # first miss (or collision) ends the chain
    num_cached_blocks += 1
    if block_id in self.used_block_ids:
        num_new_blocks -= 1                    # shared live block: no free block needed
if len(self.free_block_ids) < num_new_blocks:
    return -1
return num_cached_blocks
```
Three details worth noticing: the stored token_ids comparison guards against hash collisions; the loop deliberately stops at num_blocks - 1 so the last (possibly partial) block is always recomputed — which also guarantees at least one query token exists to produce logits from; and -1 means “not enough memory, stall admission.”
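A self-contained toy of the chained hash and the prefix walk shows all three details at once. It swaps xxhash for hashlib.blake2b with an 8-byte digest so it needs only the standard library, and chain_hash, register, and match_prefix are hypothetical names; the structure follows compute_hash, hash_blocks, and can_allocate as quoted in this section.

```python
import hashlib
from array import array


def chain_hash(token_ids: list[int], prefix: int = -1) -> int:
    """64-bit hash of one block that also commits to the hash of the block before it."""
    h = hashlib.blake2b(digest_size=8)
    if prefix != -1:
        h.update(prefix.to_bytes(8, "little"))
    h.update(array("q", token_ids).tobytes())
    return int.from_bytes(h.digest(), "little")


def register(tokens: list[int], bs: int, cache: dict[int, list[int]]) -> None:
    """Record every full block, as hash_blocks does after a step."""
    h = -1
    for i in range(len(tokens) // bs):
        block = tokens[i * bs:(i + 1) * bs]
        h = chain_hash(block, h)
        cache[h] = block


def match_prefix(tokens: list[int], bs: int, cache: dict[int, list[int]]) -> int:
    """Leading blocks already cached; like can_allocate, never matches the last block."""
    num_blocks = -(-len(tokens) // bs)
    h, hits = -1, 0
    for i in range(num_blocks - 1):
        block = tokens[i * bs:(i + 1) * bs]
        h = chain_hash(block, h)
        if cache.get(h) != block:    # miss, or a hash collision caught by the token check
            break
        hits += 1
    return hits


cache: dict[int, list[int]] = {}
system = list(range(8))                                        # shared 8-token prefix
register(system + [100, 101, 102, 103], 4, cache)
print(match_prefix(system + [7, 7, 7, 7, 9], 4, cache))        # 2: both system blocks hit
print(match_prefix([0, 1, 2, 9] + system, 4, cache))           # 0: diverges in block 0
print(match_prefix(system + [100, 101, 102, 103], 4, cache))   # 2: last block recomputed
```

The third call shows the num_blocks - 1 rule: the whole prompt is cached, yet its last block is scheduled again, which leaves a query token to produce logits from.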
allocate (nanovllm/engine/block_manager.py:75-92) then bumps refcounts on hits — including resurrecting blocks that are sitting on the free list but still hold valid data. That’s the eviction policy: deallocation just appends to the free deque without clearing the hash, so a “freed” block remains a cache hit until _allocate_block (nanovllm/engine/block_manager.py:43-51) actually reuses it and unregisters its hash:
```python
for i in range(num_cached_blocks):
    token_ids = seq.block(i)
    h = self.compute_hash(token_ids, h)
    block_id = self.hash_to_block_id[h]
    block = self.blocks[block_id]
    if block_id in self.used_block_ids:
        block.ref_count += 1                   # shared with another live sequence
    else:
        block.ref_count = 1                    # "freed" but still valid: resurrect it
        self.free_block_ids.remove(block_id)
        self.used_block_ids.add(block_id)
```
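The lazy-eviction policy is easier to see in isolation. ToyBlockPool below is a hypothetical sketch, not nano-vllm’s BlockManager API: freed blocks keep their hash registered, a hit resurrects them from the free list, and the hash is only dropped when the block is handed out for new data. It assumes allocation takes from the left of the free deque, so the least recently freed block is recycled first.

```python
from collections import deque


class ToyBlockPool:
    """Refcounted blocks whose content hashes outlive deallocation until reuse."""

    def __init__(self, num_blocks: int):
        self.ref_count = [0] * num_blocks
        self.block_hash = [-1] * num_blocks
        self.hash_to_block_id: dict[int, int] = {}
        self.free_block_ids = deque(range(num_blocks))
        self.used_block_ids: set[int] = set()

    def allocate_fresh(self) -> int:
        """Hand out the least recently freed block; its old hash is evicted here."""
        block_id = self.free_block_ids.popleft()
        old = self.block_hash[block_id]
        if old != -1 and self.hash_to_block_id.get(old) == block_id:
            del self.hash_to_block_id[old]
        self.block_hash[block_id] = -1
        self.ref_count[block_id] = 1
        self.used_block_ids.add(block_id)
        return block_id

    def acquire_cached(self, h: int) -> int:
        """Reuse a block by hash (resurrecting it if it sits on the free list); -1 on miss."""
        block_id = self.hash_to_block_id.get(h, -1)
        if block_id == -1:
            return -1
        if block_id in self.used_block_ids:
            self.ref_count[block_id] += 1          # shared by another live sequence
        else:
            self.free_block_ids.remove(block_id)   # O(n) on a deque
            self.ref_count[block_id] = 1
            self.used_block_ids.add(block_id)
        return block_id

    def register(self, block_id: int, h: int) -> None:
        """Mark a full block as cacheable under content hash h."""
        self.block_hash[block_id] = h
        self.hash_to_block_id[h] = block_id

    def release(self, block_id: int) -> None:
        """Drop one reference; at zero the block is free but stays a cache hit."""
        self.ref_count[block_id] -= 1
        if self.ref_count[block_id] == 0:
            self.used_block_ids.remove(block_id)
            self.free_block_ids.append(block_id)


pool = ToyBlockPool(2)
b = pool.allocate_fresh()
pool.register(b, 0xABC)
pool.release(b)                          # freed, but its hash is still registered
print(pool.acquire_cached(0xABC) == b)   # True: resurrected from the free list
pool.release(b)
pool.allocate_fresh()                    # memory pressure: both blocks get reused
pool.allocate_fresh()
print(pool.acquire_cached(0xABC))        # -1: the hash was evicted on reuse
```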
During decode, growth is incremental — one new block exactly when the latest token crossed a block boundary, with a bool-as-int trick:
```python
def can_append(self, seq: Sequence) -> bool:
    # needs one free block iff the newest token opens a new block (True == 1)
    return len(self.free_block_ids) >= (len(seq) % self.block_size == 1)

def may_append(self, seq: Sequence):
    if len(seq) % self.block_size == 1:
        seq.block_table.append(self._allocate_block())
```
(nanovllm/engine/block_manager.py:103-108.) Finally, after every model step, hash_blocks registers any block that just became full — note this runs during decode too, so generated tokens become prefix-cacheable (multi-turn reuse for free):
```python
def hash_blocks(self, seq: Sequence):
    start = seq.num_cached_tokens // self.block_size
    end = (seq.num_cached_tokens + seq.num_scheduled_tokens) // self.block_size
    if start == end:
        return  # no block became full this step
    # resume the hash chain from the last block that was already full
    h = self.blocks[seq.block_table[start - 1]].hash if start > 0 else -1
    for i in range(start, end):
        block = self.blocks[seq.block_table[i]]
        token_ids = seq.block(i)
        h = self.compute_hash(token_ids, h)
        block.update(h, token_ids)
        self.hash_to_block_id[h] = block.block_id
```
(nanovllm/engine/block_manager.py:110-120.)
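Putting may_append and hash_blocks together: the toy below grows one sequence token by token with a small block size and reports, per decode step, how many blocks became full and hash-registered. grow() is a hypothetical helper that tracks only counts (no KV data) and assumes the prompt was prefilled in a single step that also sampled the first output token.

```python
def grow(prompt_len: int, new_tokens: int, block_size: int) -> tuple[int, list[int]]:
    """Toy decode growth for one sequence.

    Returns (blocks held at the end, number of blocks that became full and
    hash-registered after each decode step).
    """
    num_blocks = -(-prompt_len // block_size)   # prefill allocates ceil(P / bs) blocks
    num_cached = prompt_len                     # KV entries written by the prefill
    length = prompt_len + 1                     # prefill also sampled one token
    hashed_per_step = []
    for _ in range(new_tokens):
        if length % block_size == 1:            # may_append: newest token opens a block
            num_blocks += 1
        start = num_cached // block_size        # hash_blocks with one scheduled token
        end = (num_cached + 1) // block_size
        hashed_per_step.append(end - start)
        num_cached += 1                         # postprocess: KV for that token exists
        length += 1                             # ...and the sampled token is appended
    return num_blocks, hashed_per_step


# (3, [0, 1, 0, 0]): block 1 fills on step 2, block 2 opens on step 3
print(grow(prompt_len=6, new_tokens=4, block_size=4))
```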
Continuous batching, chunked prefill, preemption
Scheduler (nanovllm/engine/scheduler.py) keeps two deques, waiting and running. The prefill pass (nanovllm/engine/scheduler.py:30-52) admits from waiting FCFS while under max_num_seqs (512) and max_num_batched_tokens (16384), discounting prefix-cache hits from the work and chunking long prompts:
```python
# excerpt from the loop over self.waiting; `remaining` is this step's unspent token budget
if not seq.block_table:                        # first time this sequence is admitted
    num_cached_blocks = self.block_manager.can_allocate(seq)
    if num_cached_blocks == -1:
        break
    num_tokens = seq.num_tokens - num_cached_blocks * self.block_size
else:                                          # a later chunk of an admitted prompt
    num_tokens = seq.num_tokens - seq.num_cached_tokens
if remaining < num_tokens and scheduled_seqs:  # only allow chunked prefill for the first seq
    break
if not seq.block_table:
    self.block_manager.allocate(seq, num_cached_blocks)
seq.num_scheduled_tokens = min(num_tokens, remaining)
```
(nanovllm/engine/scheduler.py:35-46.) A sequence whose prompt doesn’t fit in the remaining budget gets a partial chunk (only as the first seq of the batch) and stays in waiting with its block_table already fully allocated; it’s promoted to running only once the whole prompt has been scheduled (nanovllm/engine/scheduler.py:48-51). Blocks for the entire prompt are allocated up front at admission — simpler than vLLM, which allocates as chunks proceed.
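For a single long prompt the budget arithmetic reduces to a short loop. chunk_plan is a hypothetical helper that assumes the sequence is alone in the waiting queue and that any prefix-cache hit has already been discounted; it returns the PREFILL chunk sizes, and the sequence would join running only after the last one.

```python
def chunk_plan(prompt_len: int, budget: int, cached_tokens: int = 0) -> list[int]:
    """PREFILL chunk sizes for one prompt scheduled alone under a per-step token budget."""
    assert 0 <= cached_tokens < prompt_len   # the last block is never a cache hit
    chunks = []
    done = cached_tokens                     # plays the role of num_cached_tokens
    while done < prompt_len:
        chunk = min(prompt_len - done, budget)   # min(num_tokens, remaining)
        chunks.append(chunk)
        done += chunk
    return chunks


print(chunk_plan(4096, 1024))                          # [1024, 1024, 1024, 1024]
print(chunk_plan(4196, 1024, cached_tokens=2 * 256))   # two cached blocks skipped
```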
If nothing was admitted, the decode pass runs — and this is where requests leave the batch under memory pressure:
```python
while self.running and len(scheduled_seqs) < self.max_num_seqs:
    seq = self.running.popleft()               # oldest running sequence first
    while not self.block_manager.can_append(seq):
        if self.running:
            self.preempt(self.running.pop())   # shed the youngest
        else:
            self.preempt(seq)                  # nobody left to shed but itself
            break
    else:                                      # no break: a block is available
        seq.num_scheduled_tokens = 1
        seq.is_prefill = False
        self.block_manager.may_append(seq)
        scheduled_seqs.append(seq)
assert scheduled_seqs
self.running.extendleft(reversed(scheduled_seqs))
```
(nanovllm/engine/scheduler.py:58-72.) Victim selection is self.running.pop() — the youngest running sequence is sacrificed for the oldest (LIFO shedding, FCFS service). preempt (nanovllm/engine/scheduler.py:75-79) frees all the victim’s blocks and pushes it to the front of waiting: preemption is full recompute, no CPU swap — softened in practice because its full blocks are still hash-registered, so re-prefill is mostly cache hits if they haven’t been reused. The while/else is Python’s no-break clause; the extendleft(reversed(...)) returns scheduled seqs to the front in original order.
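The shedding order can be replayed without a GPU. In the toy below (decode_pass is hypothetical), block accounting is reduced to a per-sequence count, a sequence needs one free block exactly when its newest token opens a new block, victims come off the right end of running, and preempted sequences go to the front of waiting. It drops the max_num_seqs cap and the non-empty assertion from the real loop.

```python
from collections import deque


def decode_pass(running: deque, waiting: deque, held: dict[str, int],
                wants_block: set[str], free: int) -> tuple[list[str], int]:
    """Toy decode pass over sequence ids; returns (scheduled ids, free blocks left).

    held[s] is how many blocks s owns; s in wants_block means its newest token
    starts a new block this step (len(seq) % block_size == 1 in nano-vllm).
    """
    def preempt(victim: str) -> None:
        nonlocal free
        free += held.pop(victim)          # recompute-style: drop every block
        waiting.appendleft(victim)        # retry from the front of the queue

    scheduled: list[str] = []
    while running:
        seq = running.popleft()           # oldest first
        need = int(seq in wants_block)    # the bool-as-int of can_append
        while free < need:
            if running:
                preempt(running.pop())    # shed the youngest running sequence
            else:
                preempt(seq)              # nobody left to shed but itself
                break
        else:                             # no break: the block is available
            free -= need
            held[seq] += need
            scheduled.append(seq)
    running.extendleft(reversed(scheduled))  # back to the front, original order
    return scheduled, free


running = deque(["A", "B", "C"])          # A arrived first
waiting: deque = deque()
held = {"A": 3, "B": 2, "C": 1}
print(decode_pass(running, waiting, held, wants_block={"A", "B"}, free=0))
print(list(running), list(waiting))
```

Here A is scheduled after C is shed, B then has nobody left to shed and preempts itself, so waiting becomes [B, C] with the older sequence first, and only A decodes this step.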
postprocess (nanovllm/engine/scheduler.py:81-92) closes the loop — registers new block hashes, advances counters, appends the sampled token (discarding it for mid-prompt chunks), and retires finished sequences:
```python
for seq, token_id in zip(seqs, token_ids):
    self.block_manager.hash_blocks(seq)        # register blocks that just became full
    seq.num_cached_tokens += seq.num_scheduled_tokens
    seq.num_scheduled_tokens = 0
    if is_prefill and seq.num_cached_tokens < seq.num_tokens:
        continue                               # mid-prompt chunk: discard the sample
    seq.append_token(token_id)
    if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens:
        seq.status = SequenceStatus.FINISHED
        self.block_manager.deallocate(seq)
        self.running.remove(seq)
```
That continue is chunked prefill’s tell: the sampler always samples, but the token is thrown away unless this chunk completed the prompt.
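The same bookkeeping for one sequence, with a dict standing in for Sequence: post_step is a hypothetical helper and EOS = 2 is an assumed token id. A 10-token prompt prefilled in chunks of 6 and 4 shows the first sample being discarded and the second one kept.

```python
EOS = 2  # assumed end-of-sequence token id for this toy


def post_step(seq: dict, token_id: int, is_prefill: bool) -> str:
    """Advance one toy sequence after a step: 'chunk', 'running' or 'finished'."""
    seq["num_cached_tokens"] += seq["num_scheduled_tokens"]
    seq["num_scheduled_tokens"] = 0
    if is_prefill and seq["num_cached_tokens"] < len(seq["tokens"]):
        return "chunk"                    # prompt not done yet: sample is discarded
    seq["tokens"].append(token_id)
    completion = len(seq["tokens"]) - seq["num_prompt_tokens"]
    if (not seq["ignore_eos"] and token_id == EOS) or completion == seq["max_tokens"]:
        return "finished"
    return "running"


prompt = list(range(10, 20))              # 10 prompt tokens
seq = {"tokens": list(prompt), "num_prompt_tokens": len(prompt), "num_cached_tokens": 0,
       "num_scheduled_tokens": 6, "max_tokens": 3, "ignore_eos": False}
print(post_step(seq, 99, is_prefill=True))    # chunk: 6 of 10 prompt tokens done
seq["num_scheduled_tokens"] = 4
print(post_step(seq, 5, is_prefill=True))     # running: prompt done, token 5 kept
seq["num_scheduled_tokens"] = 1
print(post_step(seq, EOS, is_prefill=False))  # finished: EOS sampled
```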
Model runner: forward pass, CUDA graphs, sampling
ModelRunner (nanovllm/engine/model_runner.py) is one process per TP rank. Rank 0 lives in the engine process; ranks 1..N-1 are spawned (nanovllm/engine/llm_engine.py:24-31) and block in loop() reading method calls from a 1 MB SharedMemory segment signaled by mp.Event (nanovllm/engine/model_runner.py:61-89) — a ~30-line RPC standing in for vLLM’s ZMQ executor. Sequence.__getstate__ (nanovllm/engine/sequence.py:72-74) keeps that pickle tiny: decode steps ship only the last token, not the whole sequence.
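The compact pickle comes from shipping the full token list only while a sequence has no completion tokens, and just the newest token afterwards, together with the counters and block table a worker needs to build its inputs. ToySeq is a hypothetical stand-in written from that description, not nano-vllm’s Sequence class.

```python
import pickle


class ToySeq:
    """Hypothetical stand-in for a sequence that pickles compactly for TP workers."""

    def __init__(self, prompt: list[int]):
        self.token_ids = list(prompt)
        self.num_prompt_tokens = len(prompt)
        self.num_tokens = len(prompt)
        self.last_token = prompt[-1]
        self.block_table: list[int] = []

    def __len__(self) -> int:
        return self.num_tokens

    @property
    def num_completion_tokens(self) -> int:
        return self.num_tokens - self.num_prompt_tokens

    def append_token(self, token_id: int) -> None:
        self.token_ids.append(token_id)
        self.last_token = token_id
        self.num_tokens += 1

    def __getstate__(self):
        # Prefill needs every prompt token; a decode step needs only the newest one.
        payload = self.token_ids if self.num_completion_tokens == 0 else self.last_token
        return (self.num_tokens, self.num_prompt_tokens, self.block_table, payload)

    def __setstate__(self, state):
        self.num_tokens, self.num_prompt_tokens, self.block_table, payload = state
        if self.num_completion_tokens == 0:
            self.token_ids, self.last_token = payload, payload[-1]
        else:
            self.token_ids, self.last_token = [], payload   # full history stays on rank 0


seq = ToySeq(list(range(4000)))
prefill_bytes = len(pickle.dumps(seq))
seq.append_token(42)
decode_bytes = len(pickle.dumps(seq))
worker_copy = pickle.loads(pickle.dumps(seq))
print(prefill_bytes, decode_bytes, worker_copy.last_token, len(worker_copy))
```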
At startup it runs a max-size dummy prefill to measure peak activation memory, then sizes the KV cache to fill whatever is left and aliases slices of one giant tensor into every attention layer:
```python
num_kv_heads = hf_config.num_key_value_heads // self.world_size   # KV heads on this TP rank
head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
# bytes for one block: K and V, every layer, block_size tokens
block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.dtype.itemsize
# total/used/peak/current are memory readings taken after the warmup prefill (not shown)
config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
assert config.num_kvcache_blocks > 0
self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, head_dim)
layer_id = 0
for module in self.model.modules():
    if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
        # each attention layer gets views into the one big tensor, not copies
        module.k_cache = self.kv_cache[0, layer_id]
        module.v_cache = self.kv_cache[1, layer_id]
        layer_id += 1
```
(nanovllm/engine/model_runner.py:110-121.) Per step, run() (nanovllm/engine/model_runner.py:214-220) builds tensors via prepare_prefill (ragged varlen batch: flattened input_ids, cu_seqlens_q/k, and a slot_mapping from each token to its physical cache slot — nanovllm/engine/model_runner.py:129-170) or prepare_decode — one token per sequence:
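```python
for seq in seqs:
    input_ids.append(seq.last_token)
    positions.append(len(seq) - 1)
    context_lens.append(len(seq))
    # physical slot of the newest token: its block's base plus the offset within the block
    slot_mapping.append(seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1)
```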
(nanovllm/engine/model_runner.py:177-181.) The metadata travels to the attention layers via a process-global Context (nanovllm/utils/context.py:21-23) — nano’s version of vLLM’s ForwardContext, sidestepping threading metadata through nn.Module.forward signatures.
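Both prepare functions reduce to one mapping from a token’s logical position to a physical slot in the flat cache. slot_for and chunk_slots are hypothetical helpers; the decode line quoted above is the special case pos = len(seq) - 1, where the offset inside the last block is last_block_num_tokens - 1 (computed here as seq_len minus the tokens held by the earlier, full blocks).

```python
def slot_for(block_table: list[int], pos: int, block_size: int) -> int:
    """Physical cache slot of token `pos`: its block's base plus the in-block offset."""
    return block_table[pos // block_size] * block_size + pos % block_size


def chunk_slots(block_table: list[int], start: int, end: int, block_size: int) -> list[int]:
    """Slots for a prefill chunk covering positions [start, end), e.g. after a cache hit."""
    return [slot_for(block_table, p, block_size) for p in range(start, end)]


block_table = [7, 2, 9]       # logical block i lives in physical block block_table[i]
print(chunk_slots(block_table, 4, 8, block_size=4))   # positions 4..7 -> physical block 2

# Decode: the newest token of a 10-token sequence sits at position 9.
seq_len, bs = 10, 4
last_block_num_tokens = seq_len - (len(block_table) - 1) * bs
assert slot_for(block_table, seq_len - 1, bs) == block_table[-1] * bs + last_block_num_tokens - 1
```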
Decode steps with batch ≤ 512 replay a pre-captured CUDA graph instead of relaunching kernels (capture_cudagraph at nanovllm/engine/model_runner.py:222-257 captures graphs for batch sizes 1, 2, 4, 8, 16, 32, … sharing one memory pool). Replay = copy inputs into static buffers, fire:
```python
bs = input_ids.size(0)
context = get_context()
graph = self.graphs[next(x for x in self.graph_bs if x >= bs)]  # smallest graph that fits
graph_vars = self.graph_vars
graph_vars["input_ids"][:bs] = input_ids
graph_vars["positions"][:bs] = positions
graph_vars["slot_mapping"].fill_(-1)           # padded rows write nowhere
graph_vars["slot_mapping"][:bs] = context.slot_mapping
graph_vars["context_lens"].zero_()
graph_vars["context_lens"][:bs] = context.context_lens
graph_vars["block_tables"][:bs, :context.block_tables.size(1)] = context.block_tables
graph.replay()
return self.model.compute_logits(graph_vars["outputs"][:bs])  # drop the padded rows
```
(nanovllm/engine/model_runner.py:200-212.) The slot_mapping.fill_(-1) padding is honored by the Triton store kernel, which skips slot −1 (nanovllm/layers/attention.py:23). Prefill always runs eager — graphs only pay off for the many small identical decode launches.
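The replay-side bookkeeping can be modeled in plain Python: round the live batch up to the smallest captured size, then pad the slot mapping with -1 so the store kernel skips the padded rows. pick_graph and pad_slots are hypothetical helpers, and the bucket list in the usage line contains only the sizes the text lists, not the full captured set.

```python
def pick_graph(bs: int, captured: list[int]) -> int:
    """Smallest captured batch size >= bs (captured must be sorted ascending).

    Raises StopIteration if bs exceeds the largest graph; per the text, batches
    above 512 run eagerly and never reach this lookup.
    """
    return next(x for x in captured if x >= bs)


def pad_slots(slots: list[int], graph_bs: int) -> list[int]:
    """Static-buffer contents: real slots first, then -1 so the store kernel skips them."""
    return slots + [-1] * (graph_bs - len(slots))


captured = [1, 2, 4, 8, 16, 32]     # only the sizes listed in the text
bs = pick_graph(5, captured)
print(bs, pad_slots([12, 40, 41, 77, 90], bs))
```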
Inside each attention layer, the whole paged-attention story is two flash-attn calls plus the cache-write kernel:
```python
if k_cache.numel() and v_cache.numel():
    store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)  # write this step's K/V
if context.is_prefill:
    if context.block_tables is not None:  # prefix cache: attend over the paged cache
        k, v = k_cache, v_cache
    o = flash_attn_varlen_func(q, k, v,
                               max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
                               max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
                               softmax_scale=self.scale, causal=True, block_table=context.block_tables)
else:  # decode: one query token per sequence
    o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
                                cache_seqlens=context.context_lens, block_table=context.block_tables,
                                softmax_scale=self.scale, causal=True)
```
(nanovllm/layers/attention.py:62-74.) Sampling happens on rank 0 only — ParallelLMHead first slices each sequence’s last position (nanovllm/layers/embed_head.py:58-60) so logits are computed for one token per sequence, TP shards are gathered to rank 0 (nanovllm/layers/embed_head.py:62-65), then the entire sampler is six lines of Gumbel-max:
```python
@torch.compile
def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
    logits = logits.float().div_(temperatures.unsqueeze(dim=1))  # per-sequence temperature
    probs = torch.softmax(logits, dim=-1)
    # argmax(p / E) with E ~ Exp(1): an exact categorical sample
    sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)).argmax(dim=-1)
    return sample_tokens
```
(nanovllm/layers/sampler.py:7-12.) Dividing probabilities by i.i.d. Exponential(1) noise and taking argmax is an exact categorical sample — one fused kernel, no top-k/top-p machinery.
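Why the trick is exact: if the E_i ~ Exp(1) are independent, then E_i / p_i is exponential with rate p_i, and the smallest of independent exponentials falls on index i with probability p_i / Σ_j p_j. argmax_i p_i / E_i picks the same index as argmin_i E_i / p_i. A stdlib-only empirical check follows; race_sample is a hypothetical name, and nano-vllm does the same computation on GPU tensors.

```python
import math
import random


def race_sample(logits: list[float], temperature: float, rng: random.Random) -> int:
    """Categorical sample via argmax_i p_i / E_i with E_i ~ Exp(1)."""
    scaled = [x / temperature for x in logits]          # temperature must be > 0
    top = max(scaled)
    weights = [math.exp(x - top) for x in scaled]       # numerically stable softmax
    total = sum(weights)
    probs = [w / total for w in weights]
    noise = [max(rng.expovariate(1.0), 1e-10) for _ in probs]   # same clamp as above
    return max(range(len(probs)), key=lambda i: probs[i] / noise[i])


rng = random.Random(0)
logits = [math.log(0.2), math.log(0.5), math.log(0.3)]
counts = [0, 0, 0]
for _ in range(20_000):
    counts[race_sample(logits, 1.0, rng)] += 1
print([c / 20_000 for c in counts])   # close to [0.2, 0.5, 0.3]
```

Note the division by temperature: a temperature of 0 would divide by zero, which is consistent with greedy decoding being asserted away in sampling_params.py rather than special-cased.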
What a Sequence carries
Sequence (nanovllm/engine/sequence.py) is the unit the scheduler moves around: token list, num_prompt_tokens, num_cached_tokens (KV entries that already exist — from prefix hits, completed chunks, and each decode step), num_scheduled_tokens (work assigned this step), its block_table (logical→physical block ids), and the three sampling fields. The class-level block_size is injected from config at engine startup (nanovllm/engine/llm_engine.py:21).
```text
add_request()           schedule(): blocks allocated,             postprocess(): EOS or
           ┌─────────┐  prompt chunks scheduled       ┌─────────┐ max_tokens reached  ┌──────────┐
arrival ──►│ WAITING │ ──────────────────────────────►│ RUNNING │────────────────────►│ FINISHED │
           └─────────┘                                └─────────┘ (blocks freed,      └──────────┘
                ▲                                          │       seq leaves running)
                └──────── preempt(): ALL blocks freed ◄────┘
                          re-enters at FRONT of waiting; re-prefill is
                          cheap if its blocks are still in the hash map
```
There is no SWAPPED state (nanovllm/engine/sequence.py:8-11) — vLLM V0 had one; nano (like vLLM V1’s default) preempts by recompute.
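The diagram collapses into a small transition table. Status mirrors the three states named above; the event strings are hypothetical labels for the arrows, not nano-vllm identifiers.

```python
from enum import Enum, auto


class Status(Enum):
    WAITING = auto()
    RUNNING = auto()
    FINISHED = auto()


# (state, event) -> next state. There is no SWAPPED state and no way out of FINISHED.
TRANSITIONS = {
    (Status.WAITING, "partial_prompt_chunk"): Status.WAITING,   # stays queued mid-prompt
    (Status.WAITING, "last_prompt_chunk"): Status.RUNNING,
    (Status.RUNNING, "preempted"): Status.WAITING,              # all blocks freed
    (Status.RUNNING, "eos_or_max_tokens"): Status.FINISHED,     # blocks freed
}


def advance(state: Status, event: str) -> Status:
    try:
        return TRANSITIONS[(state, event)]
    except KeyError:
        raise ValueError(f"illegal transition: {state.name} on {event!r}") from None


s = Status.WAITING
for ev in ("partial_prompt_chunk", "last_prompt_chunk", "preempted",
           "last_prompt_chunk", "eos_or_max_tokens"):
    s = advance(s, ev)
print(s)   # Status.FINISHED
```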
nano-vllm → vLLM V1 map
| nano-vllm | vLLM V1 equivalent | Delta |
|---|---|---|
LLMEngine (nanovllm/engine/llm_engine.py) | LLMEngine / AsyncLLM + EngineCore (vllm/v1/engine/llm_engine.py, core.py) | vLLM splits the core loop into its own process behind ZMQ; adds async streaming |
Scheduler (nanovllm/engine/scheduler.py) | Scheduler (vllm/v1/core/sched/scheduler.py) | V1 erased the prefill/decode dichotomy: one token-budget loop assigns num_scheduled_tokens per request and mixes both in a single batch; adds priority, spec-decode tokens, structured-output gating |
waiting / running deques | same-named queues in V1’s scheduler | V1 adds priority policies; same FCFS default |
Sequence (nanovllm/engine/sequence.py) | Request (vllm/v1/request.py) + GPU-side CachedRequestState (vllm/v1/worker/gpu_input_batch.py) | V0’s Sequence/SequenceGroup is the historical ancestor |
BlockManager (nanovllm/engine/block_manager.py) | KVCacheManager (vllm/v1/core/kv_cache_manager.py) + BlockPool (block_pool.py), hashing in kv_cache_utils.py | Same chained block hashing; vLLM adds LRU eviction, hybrid KV groups (sliding window etc.), and an explicit free-queue order |
num_cached_tokens / chunk accounting | num_computed_tokens on Request | Same concept, same role in chunked prefill |
ModelRunner (nanovllm/engine/model_runner.py) | GPUModelRunner + Worker (vllm/v1/worker/gpu_model_runner.py, gpu_worker.py) | vLLM keeps a persistent on-GPU input batch updated by diffs instead of rebuilding tensors per step |
SharedMemory + Event RPC (nanovllm/engine/model_runner.py:61-89) | MultiprocExecutor (vllm/v1/executor/multiproc_executor.py) shm message queues; Ray for multi-node | Same shape: rank 0 broadcasts method calls to workers |
Context global (nanovllm/utils/context.py) | ForwardContext (vllm/forward_context.py) + per-backend AttentionMetadata | Identical trick: process-global side channel into attention |
flash-attn calls (nanovllm/layers/attention.py:67-74) | pluggable backends (vllm/v1/attention/backends/: FlashAttention, FlashInfer, Triton, FlexAttention…) | nano hardcodes one backend |
capture_cudagraph (nanovllm/engine/model_runner.py:222-257) | full + piecewise CUDA graphs via torch.compile (vllm/compilation/) | piecewise = graph the transformer chunks, leave attention eager |
Sampler (nanovllm/layers/sampler.py) | Sampler (vllm/v1/sample/sampler.py) | vLLM adds top-k/p, penalties, logprobs, seeds, bad-words, structured-output bitmasks |
loader.py + weight_loader attrs | vllm/model_executor/model_loader/ | Same per-parameter loader idiom, same packed_modules_mapping for fused QKV/gate-up |
What’s deliberately omitted
Do not expect to find any of the following — knowing they’re absent is what keeps the codebase at 1.4k lines:
- Serving: no API server, no OpenAI endpoints, no streaming, no async engine. generate() is a blocking batch call; detokenization happens once at the end (nanovllm/engine/llm_engine.py:89).
- Sampling: no top-k/top-p, penalties, logprobs, stop strings, seeds, or beam search, and greedy decoding is explicitly forbidden (nanovllm/sampling_params.py:11).
- Structured output: no grammar/JSON-schema constrained decoding (no logits-processor hook at all).
- Quantization: none (no GPTQ/AWQ/FP8) — checkpoints load at the HF config dtype, period.
- Adapters & spec decode: no LoRA/multi-LoRA, no speculative decoding (EAGLE/Medusa/ngram).
- Models: Qwen3 dense only — no MoE, no multimodal, no sliding-window/hybrid attention, no encoder-decoder.
- Parallelism: single-node TP only (≤8, NCCL hardcoded); no pipeline/expert/data parallel, no Ray, no multi-node.
- KV mobility: preemption is recompute-only — no CPU swap, no KV offload tiers, no KV connectors / disaggregated prefill-decode transfer.
- Operations: no metrics, tracing, health checks, or scheduling policies beyond FCFS.
- Portability: CUDA only — no ROCm/CPU/TPU backends.
Connections to your other study repos
- vllm: direct target. Read nano’s scheduler, then vllm/v1/core/sched/scheduler.py; nano’s BlockManager, then vllm/v1/core/kv_cache_manager.py + block_pool.py. The single most instructive diff: V1 has no is_prefill flag anywhere in scheduling. Convince yourself that nano’s two-phase schedule() collapses into V1’s one token-budget loop.
- sglang: same engine loop, different prefix-cache data structure. RadixAttention keeps a radix tree over token sequences (python/sglang/srt/mem_cache/radix_cache.py) where nano keeps a flat hash map; the tree enables cache-aware scheduling (reordering the waiting queue by match length, in contrast to nano’s strict FCFS) and partial-block matching. sglang’s Scheduler is python/sglang/srt/managers/scheduler.py.
- flashinfer: nano’s entire kernel contract is three calls, the Triton cache-write plus flash_attn_varlen_func / flash_attn_with_kvcache (nanovllm/layers/attention.py). flashinfer provides exactly those two attention ops (BatchPrefill/BatchDecode over paged KV) with a plan/run split and JIT specialization; it’s what vLLM/sglang slot into the seam nano hardcodes.
- xgrammar: would plug in between compute_logits and the sampler (nanovllm/engine/model_runner.py:217-218) as a per-sequence token bitmask applied to logits. nano having no hook there is the lesson: you can see precisely where structured output costs latency, and why vLLM gates “waiting for FSM compilation” in the scheduler.
- dynamo: nano’s strict per-step prefill/decode mutual exclusion is the single-GPU shadow of disaggregated P/D serving. dynamo moves that split across worker pools and ships KV via NIXL, and its KV-aware router consumes events about exactly the state that is hash_to_block_id here.
- llm-d / gateway-api-inference-extension: these balance across replicas of engines like this one. The EPP scorers (prefix-cache-aware, queue-depth, load) are estimating, from outside, what can_allocate would return and how long waiting is inside each pod. The instrumentation section below prints precisely the signal set an inference gateway wants engines to export.
Running it on your RTX 5080 (16 GB)
Use WSL2, not native Windows. Three hard blockers natively: dist.init_process_group("nccl", ...) is hardcoded even for TP=1 (nanovllm/engine/model_runner.py:26) and Windows torch builds ship no NCCL; Triton has no official Windows wheels (the store_kvcache kernel and torch.compile need it); flash-attn has no official Windows wheels. All three work in WSL2 Ubuntu with the standard CUDA-on-WSL driver. (Fallback hack if you must stay native: locally change "nccl" to "gloo", install triton-windows, and hunt a community flash-attn wheel — WSL2 is less total work.)
Blackwell specifics. The 5080 is sm_120: you need a cu128 build of PyTorch (≥2.7) and a flash-attn built against it — there’s frequently no prebuilt sm_120 wheel, so expect a source build:
```bash
pip install torch --index-url https://download.pytorch.org/whl/cu128
MAX_JOBS=4 TORCH_CUDA_ARCH_LIST="12.0" pip install flash-attn --no-build-isolation  # long compile
pip install transformers xxhash huggingface_hub
huggingface-cli download Qwen/Qwen3-4B --local-dir ~/huggingface/Qwen3-4B/
```
Config asserts the model path is a local directory (nanovllm/config.py:21) — you cannot pass a hub id.
Model choice. No quantization support means weights load in bf16, so the ceiling is lower than you’d guess. KV cost per 256-token block is 2 * layers * 256 * kv_heads * head_dim * 2 bytes (nanovllm/engine/model_runner.py:112):
| Model | bf16 weights | KV per token | KV per block | Fits 16 GB? |
|---|---|---|---|---|
| Qwen3-0.6B | ~1.4 GB | 112 KiB | 28 MiB | Yes — huge cache (~100k tokens), best for scheduler experiments |
| Qwen3-1.7B | ~3.4 GB | 112 KiB | 28 MiB | Yes, comfortable |
| Qwen3-4B | ~8 GB | 144 KiB | 36 MiB | Yes — the sweet spot; roughly 130–160 blocks (~35–40k KV tokens) left |
| Qwen3-8B | ~16.4 GB | — | — | No. Quantized 8B isn’t possible here; do Qwen3-8B-AWQ in real vLLM instead |
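The table’s KV columns follow directly from the block_bytes formula. The calculator below assumes the published Qwen3 configs (28 layers for 0.6B and 1.7B, 36 for 4B, 8 KV heads, head_dim 128; check each checkpoint’s config.json) and 2-byte bf16 entries. The free-memory figure in the last lines is made up for illustration: the real value is whatever the startup profiler leaves after weights and peak activations.

```python
def kv_bytes_per_token(layers: int, kv_heads: int, head_dim: int,
                       dtype_bytes: int = 2, tp: int = 1) -> int:
    """K and V for every layer, counting only this TP rank's share of KV heads."""
    return 2 * layers * (kv_heads // tp) * head_dim * dtype_bytes


def blocks_that_fit(free_bytes: int, per_token: int, block_size: int = 256) -> int:
    """Floor division, like num_kvcache_blocks = available // block_bytes."""
    return free_bytes // (per_token * block_size)


QWEN3 = {"0.6B": (28, 8, 128), "1.7B": (28, 8, 128), "4B": (36, 8, 128)}  # assumed configs
for name, (layers, kv_heads, head_dim) in QWEN3.items():
    per_tok = kv_bytes_per_token(layers, kv_heads, head_dim)
    print(f"{name}: {per_tok // 1024} KiB/token, {per_tok * 256 // 2**20} MiB/block")

# Illustration only: suppose 5 GiB were left over for the 4B model.
n = blocks_that_fit(5 * 2**30, kv_bytes_per_token(36, 8, 128))
print(n, "blocks =", n * 256, "KV tokens")
```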
Start with enforce_eager=True (it skips capturing ~36 CUDA graphs and gives far friendlier stack traces), and use gpu_memory_utilization=0.8 if your desktop compositor holds VRAM. Then check what the profiler decided: llm.config isn’t exposed, but len(llm.scheduler.block_manager.free_block_ids) right after construction equals num_kvcache_blocks.
Watching the scheduler think. Without touching repo files, monkeypatch from your own driver script:
```python
from nanovllm import LLM, SamplingParams   # used by the rest of your driver script
from nanovllm.engine.scheduler import Scheduler

orig_schedule, orig_preempt = Scheduler.schedule, Scheduler.preempt

def schedule(self):
    seqs, is_prefill = orig_schedule(self)
    print(f"{'PREFILL' if is_prefill else 'DECODE '}"
          f" batch={len(seqs)}"
          f" sched_tokens={sum(s.num_scheduled_tokens for s in seqs)}"
          f" waiting={len(self.waiting)} running={len(self.running)}"
          f" free_blocks={len(self.block_manager.free_block_ids)}"
          f" cached_hashes={len(self.block_manager.hash_to_block_id)}")
    return seqs, is_prefill

def preempt(self, seq):
    print(f"  PREEMPT seq={seq.seq_id} len={len(seq)} (recompute later)")
    orig_preempt(self, seq)

# Patching the class affects every Scheduler instance, including the engine's.
Scheduler.schedule, Scheduler.preempt = schedule, preempt
```
Experiments that make each mechanism visible:
- Prefix cache: call generate() twice with the same long prompt (>512 tokens, so it spans ≥2 full blocks). The second run’s PREFILL line shows sched_tokens shrunk by 256 per matched block, and seq.num_cached_tokens confirms the hits. Note that hits land only across steps/calls: two identical prompts admitted in the same prefill batch don’t dedupe, because hash_blocks runs in postprocess.
- Chunked prefill: set max_num_batched_tokens=1024 and send a ~4k-token prompt. You get successive PREFILL steps of 1024 tokens for one sequence, which joins running only on the last chunk.
- Preemption: shrink the pool (gpu_memory_utilization=0.5, with the 0.6B or 1.7B model; half of 16 GB leaves no KV room after Qwen3-4B’s ~8 GB of weights) and run bench.py-style load: many sequences with ignore_eos=True and max_tokens=1024. When free_blocks hits 0 mid-decode, watch PREEMPT pick the youngest sequences, and note that their later re-prefill is cheap (cached_hashes still holds their blocks).
- Continuous batching itself: give 32 prompts wildly different max_tokens. The DECODE batch= count steps down as short sequences finish while long ones keep going: no stragglers waiting on a static batch, no head-of-line blocking.
- CUDA graphs: flip enforce_eager and compare the tqdm decode tok/s; then read run_model (nanovllm/engine/model_runner.py:196-212) knowing exactly what you paid for.
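You can predict the prefix-cache experiment’s numbers before running it. Because can_allocate never matches a prompt’s last block, a prompt of exactly 512 tokens reuses only one block, and 513 or more tokens are needed to reuse two. expected_second_run_tokens is a hypothetical helper that assumes the first run’s blocks are still hash-registered and the prompt fits in one step’s token budget.

```python
def expected_second_run_tokens(prompt_len: int, block_size: int = 256) -> int:
    """sched_tokens on a repeated prompt's PREFILL line, if every earlier block still hits."""
    assert prompt_len >= 1
    num_blocks = -(-prompt_len // block_size)    # ceil: blocks the prompt spans
    cached_blocks = num_blocks - 1               # the last block is never matched
    return prompt_len - cached_blocks * block_size


for n in (300, 512, 513, 1000):
    print(n, "->", expected_second_run_tokens(n))
```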