SGLang
1. What it is, and how it differs from vLLM
SGLang is an LLM serving stack from the LMSYS team with two load-bearing halves: a Python serving runtime (“SRT”, python/sglang/srt/) that owns the GPU — continuous batching, paged KV, RadixAttention prefix caching, grammar-constrained decoding — and a production Rust router/data plane (sgl-model-gateway/, historically named sgl-router) that load-balances a fleet of those runtimes. Versus vLLM: both do continuous batching + paged KV + prefix caching, but SGLang’s prefix cache is a first-class radix tree over token IDs (vLLM v1 hashes fixed-size blocks; SGLang matches at arbitrary token granularity within page alignment and schedules around the tree — its waiting-queue policy literally sorts by longest-prefix-match against cache contents), its scheduler overlap (“zero-overhead scheduler”) runs CPU bookkeeping for batch N concurrently with GPU forward of batch N+1 by default, it retracts decode requests back to the waiting queue under memory pressure rather than vLLM-style preemption semantics, and — the part that matters for you — it ships a cross-replica cache-aware router as a supported component rather than leaving fleet-level routing to an external gateway. The repo also carries sgl-kernel/ (AOT CUDA kernels), a frontend DSL (python/sglang/lang/ — the original “Structured Generation Language” that gave the project its name; co-pilot-style programs with gen() calls; you can skip it), and a gRPC protocol (proto/, rust/sglang-grpc/) that lets the Rust gateway talk to schedulers while doing tokenization in Rust.
2. Why you care
You sketched cache-aware routing with prefix trees before you knew this existed. sgl-router is the production version of that exact sketch, and it validates every design intuition you’d have brought from Envoy land: the router keeps an approximate radix tree per worker, built purely from routing history (no worker telemetry needed), routes to the worker with the highest prefix match when load is balanced, and falls back to shortest-queue when an imbalance detector trips — with the imbalance test being the same abs+rel two-threshold pattern you’d use for outlier detection in an LB. The entire policy is ~500 lines of Rust in sgl-model-gateway/src/policies/cache_aware.rs plus a ~2300-line lock-free radix tree in tree.rs, and reading it is the fastest way to see which corners production cut that your sketch probably didn’t (char-level text instead of tokens, optimistic insert-at-selection-time, zero feedback loop from workers). It also gives you the counterpoint to study against NVIDIA Dynamo’s KV-event-based “ground truth” approach and the Envoy-based gateway-api-inference-extension — three points on the design space of “how much does the LB know about cache state, and who pays for that knowledge.”
3. Architecture map
sglang/ (repo root)
├── python/sglang/
│   ├── lang/                                  # frontend DSL (skippable)
│   ├── launch_server.py                       # CLI: python -m sglang.launch_server
│   ├── bench_serving.py                       # load generator (has shared-prefix dataset)
│   └── srt/                                   # THE RUNTIME
│       ├── entrypoints/http_server.py         # FastAPI app (OpenAI-compatible + /generate)
│       ├── entrypoints/engine.py              # spawns the 3-process pipeline
│       ├── managers/tokenizer_manager.py      # proc 1: tokenize, track request state
│       ├── managers/scheduler.py              # proc 2 (per TP rank): batching, KV, the GPU loop
│       ├── managers/schedule_batch.py         # Req/ScheduleBatch state machines
│       ├── managers/schedule_policy.py        # waiting-queue priority (LPM) + PrefillAdder
│       ├── managers/detokenizer_manager.py    # proc 3: incremental detokenization
│       ├── mem_cache/radix_cache.py           # RadixAttention's tree (the real one)
│       ├── mem_cache/memory_pool.py           # req_to_token + token_to_kv paged pools
│       ├── layers/radix_attention.py          # nn.Module that dispatches to attn backends
│       ├── layers/attention/                  # flashinfer_backend.py, triton_backend.py, ...
│       ├── constrained/                       # xgrammar/llguidance/outlines bitmask backends
│       └── disaggregation/                    # prefill/decode (PD) split runtime
├── sgl-model-gateway/                         # THE ROUTER (formerly sgl-router; pip: sglang-router)
│   ├── src/policies/                          # cache_aware.rs, tree.rs, power_of_two.rs, ...
│   ├── src/core/                              # worker.rs (health/load/CB), worker_registry, retry
│   ├── src/routers/http/                      # router.rs (regular), pd_router.rs (PD pairs)
│   ├── src/routers/grpc/                      # Rust-native tokenizer pipeline → SRT gRPC workers
│   └── bindings/python/                       # python -m sglang_router.launch_router
├── sgl-kernel/                                # AOT CUDA/C++ kernels (own wheel)
└── proto/, rust/sglang-grpc/                  # gRPC contract between gateway and scheduler
Process pipeline inside one worker: HTTP → TokenizerManager → ZMQ → Scheduler (one process per GPU; owns model + KV) → ZMQ → DetokenizerManager → ZMQ back to TokenizerManager → SSE out. Spawning is in python/sglang/srt/entrypoints/engine.py:584 (scheduler procs) and :196 (process wiring). Note ZMQ IPC, not gRPC, between these processes — the gRPC path is an alternative front door used by the Rust gateway (python/sglang/srt/entrypoints/grpc_server.py).
4. Core mechanisms
4.1 sgl-router: cache-aware load balancing (Rust)
Everything below lives under sgl-model-gateway/src/. The policy abstraction is LoadBalancingPolicy (policies/mod.rs:42) — select_worker(workers, info) -> Option<usize> where SelectWorkerInfo (policies/mod.rs:160) carries request_text (HTTP path), tokens (gRPC path), headers, and a consistent-hash ring. Policies: random, round_robin, cache_aware, power_of_two, prefix_hash, consistent_hashing, bucket, manual. Default is cache_aware (src/main.rs:150).
The routing decision
CacheAwarePolicy::select_worker (policies/cache_aware.rs:376) is a two-mode policy. First it computes min/max in-flight load across workers and runs the imbalance detector:
// Check if load is imbalanced
let is_imbalanced = max_load.saturating_sub(min_load) > self.config.balance_abs_threshold
    && (max_load as f32) > (min_load as f32 * self.config.balance_rel_threshold);
if is_imbalanced {
    return self.select_worker_min_load(
        workers,
        &request_text,
        &healthy_indices,
        &tree_key,
        max_load,
        min_load,
    );
}
— policies/cache_aware.rs:400-413. Both conditions must hold (abs gap AND ratio), the same belt-and-suspenders you’d use so that tiny absolute deltas at low traffic and proportionally-small deltas at high traffic both stay in cache-affinity mode. CLI defaults: cache_threshold 0.3, balance_abs_threshold 64, balance_rel_threshold 1.5, eviction every 120s, max_tree_size 67108864 chars (src/main.rs:153-171; the in-code CacheAwareConfig::default() at policies/mod.rs:106-116 differs — CLI wins).
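To get a feel for how the two thresholds interact, here is a minimal Python sketch of the same predicate. The function name and the sample loads are made up for illustration; only the comparison itself mirrors the Rust excerpt, with the CLI defaults quoted above (64 and 1.5).

```python
def is_imbalanced(loads, abs_threshold=64, rel_threshold=1.5):
    """Return True only when both the absolute gap and the ratio test trip."""
    max_load, min_load = max(loads), min(loads)
    gap_exceeded = (max_load - min_load) > abs_threshold
    ratio_exceeded = max_load > min_load * rel_threshold
    return gap_exceeded and ratio_exceeded


# Low traffic: the ratio is large (10 vs 1) but the gap is only 9.
print(is_imbalanced([1, 10]))       # False
# High traffic: the gap is 100 but the ratio is only 1.1.
print(is_imbalanced([1000, 1100]))  # False
# Both tests trip: gap 190, ratio 20.
print(is_imbalanced([10, 200]))     # True
```

The first two cases stay in cache-affinity mode for opposite reasons, which is the point of requiring both conditions.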
When balanced, it consults the approximate tree:
let result = tree.prefix_match_with_counts(text);
let match_rate = if result.input_char_count == 0 {
    0.0
} else {
    result.matched_char_count as f32 / result.input_char_count as f32
};
// Select worker without String allocation
let selected_idx = if match_rate > self.config.cache_threshold {
    // Cache hit path: find worker by URL (compare &str directly, no allocation)
    let tenant_url: &str = &result.tenant;
    workers
        .iter()
        .position(|w| w.url() == tenant_url)
        .filter(|&idx| workers[idx].is_healthy())
— policies/cache_aware.rs:425-439. Below the threshold, the else-branch (:440-446) picks the min-load healthy worker — i.e., “no one has meaningful affinity for this prefix, so place it where there’s most headroom” (the comment block at :29-30 frames it as “smallest tree = most available cache capacity”; the implementation uses live load). Then the critical learning step — the router inserts the request text under the chosen worker at selection time:
if let Some(idx) = selected_idx {
    // Update the tree with this request (use worker URL directly, no allocation)
    tree.insert(text, workers[idx].url());
}
— policies/cache_aware.rs:448-450. That insert is what builds each worker’s “approximate cache state”: the router assumes that whatever it routes to worker W will end up in W’s real radix cache. Even in imbalanced/min-load mode it still inserts (:335-343), so affinity keeps tracking reality during load-balancing episodes. If the matched tenant is dead, the policy removes it from the tree and falls back (:471-491).
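The balanced-mode decision and the insert-at-selection-time step fit in a few lines of Python if you drop the real data structure. The sketch below is a toy under loud assumptions: a per-worker list of past texts stands in for the shared radix tree, the imbalance detector and health checks are omitted, loads only ever increase, and every name (ToyCacheAwareRouter, best_match, select) is invented for illustration.

```python
import os


class ToyCacheAwareRouter:
    """Toy stand-in: a per-worker list of routed texts replaces the shared radix tree."""

    def __init__(self, workers, cache_threshold=0.3):
        self.history = {w: [] for w in workers}  # worker -> texts routed there
        self.load = {w: 0 for w in workers}      # worker -> in-flight requests
        self.cache_threshold = cache_threshold

    def best_match(self, text):
        """Return (worker, matched_chars) for the longest prefix shared with past traffic."""
        best_worker, best_len = None, 0
        for worker, texts in self.history.items():
            for old in texts:
                matched = len(os.path.commonprefix([old, text]))
                if matched > best_len:
                    best_worker, best_len = worker, matched
        return best_worker, best_len

    def select(self, text):
        worker, matched = self.best_match(text)
        match_rate = matched / len(text) if text else 0.0
        if worker is None or match_rate <= self.cache_threshold:
            # No meaningful affinity: place the request on the least-loaded worker.
            worker = min(self.load, key=self.load.get)
        # Optimistic learning step: record the text before the worker confirms anything.
        self.history[worker].append(text)
        self.load[worker] += 1
        return worker


router = ToyCacheAwareRouter(["w1", "w2"])
first = router.select("Alpha system prompt, long and shared. Q1")
second = router.select("Bravo system prompt, long and shared. Q1")
third = router.select("Alpha system prompt, long and shared. Q2")
print(first, second, third)  # w1 w2 w1
```

The third request returns to w1 purely because of what the router remembers having sent there, which is the whole mechanism: no worker ever reports what it actually cached.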
The text being matched comes from extract_text_for_routing() on the typed request (routers/http/router.rs:203) — raw prompt/messages text, characters not tokens. The gRPC router can instead route on token IDs via PrefixHashPolicy (policies/prefix_hash.rs, hashing the first --prefix-token-count 256 tokens to a consistent-hash ring — a cheaper, stateless cousin of cache-aware).
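For contrast, the stateless idea behind prefix hashing can be sketched as follows. This is not the Rust implementation: the hash function, the virtual-node count, and the class name are all assumptions chosen for the example; only the shape (hash the first N token IDs, look the result up on a consistent-hash ring) comes from the description above.

```python
import bisect
import hashlib


def stable_hash(data: bytes) -> int:
    """64-bit hash that is stable across processes (unlike Python's built-in hash)."""
    return int.from_bytes(hashlib.sha256(data).digest()[:8], "big")


class PrefixHashRing:
    def __init__(self, workers, vnodes=64, prefix_token_count=256):
        self.prefix_token_count = prefix_token_count
        # Each worker owns several points on the ring to smooth the distribution.
        self.ring = sorted(
            (stable_hash(f"{w}#{i}".encode()), w) for w in workers for i in range(vnodes)
        )
        self.points = [point for point, _ in self.ring]

    def select(self, token_ids):
        """Route on the first prefix_token_count tokens only; the suffix is ignored."""
        prefix = token_ids[: self.prefix_token_count]
        key = stable_hash(b"".join(t.to_bytes(4, "little") for t in prefix))
        idx = bisect.bisect_right(self.points, key) % len(self.ring)
        return self.ring[idx][1]


ring = PrefixHashRing(["w1", "w2", "w3"], prefix_token_count=4)
print(ring.select([11, 22, 33, 44, 5]) == ring.select([11, 22, 33, 44, 999]))  # True
```

Requests that share their first N tokens always land on the same worker with no per-request state in the router, at the cost of ignoring load and any sharing deeper than N tokens.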
The approximate radix tree
policies/tree.rs is a concurrent multi-tenant radix tree: one tree shared by all workers of a (pool, model) group, where each node tracks which workers (tenants) own that prefix segment:
#[derive(Debug)]
struct Node {
    /// Children nodes indexed by first character.
    /// Using custom hasher optimized for char keys.
    children: DashMap<char, NodeRef, CharHasherBuilder>,
    /// Node text with cached character count
    text: RwLock<NodeText>,
    /// Per-tenant last access epoch for LRU ordering. Using TenantId (Arc<str>) for cheap cloning.
    tenant_last_access_time: DashMap<TenantId, u64>,
    /// Parent pointer for upward traversal during timestamp updates
    parent: RwLock<Option<NodeRef>>,
    /// Cached last-accessed tenant for O(1) lookup during prefix match.
    /// Avoids O(shards) DashMap iteration in the common case.
    last_tenant: parking_lot::RwLock<Option<TenantId>>,
}
— policies/tree.rs:229-243. Notes a traffic-infra person will appreciate: TenantId = Arc<str> interning (:40), per-node DashMaps with tuned shard counts (32 for root, 8 elsewhere, :23-24), and LRU timestamps from a global atomic epoch counter instead of wall clock (:218-227 — “no syscall overhead… perfectly monotonic”). insert (:362) walks char-by-char, splitting nodes on partial matches (:435-482) like any radix trie; prefix_match_with_counts (:531) walks down and returns the deepest node’s tenant. Two deliberate approximations in the hot path: tenant resolution uses the last_tenant cache or whatever tenant the DashMap iterator yields first (:568-599) — when multiple workers own the same prefix it does not pick the least-loaded owner — and LRU timestamps update only 1-in-8 matches:
// Update timestamp probabilistically (1 in 8 matches) to reduce DashMap contention.
// LRU eviction doesn't need perfect accuracy - approximate timestamps suffice.
let epoch = get_epoch();
if epoch & 0x7 == 0 {
    curr.tenant_last_access_time
        .insert(Arc::clone(&tenant), epoch);
}
— policies/tree.rs:601-607.
Trees are keyed pool::model ("prefill::llama" / "decode::llama" / "regular::UNKNOWN_MODEL_ID", policies/cache_aware.rs:83-100) so in PD mode the prefill and decode pools can’t evict each other’s tenants — the doc comment at :76-82 records the regression where a shared tree collapsed into pool flip-flopping. Workers are seeded with an empty-string insert at registration (:185, :197).
Tree maintenance and eviction
A background thread runs every eviction_interval_secs and trims each tree to a per-tenant character budget (policies/cache_aware.rs:127-149). The eviction itself (policies/tree.rs:718) collects per-tenant leaves into a min-heap by LRU epoch and peels leaves until the tenant is under budget:
// Process eviction
while let Some(Reverse(entry)) = pq.pop() {
    let EvictionEntry { tenant, node, .. } = entry;
    if let Some(used_size) = self.tenant_char_count.get(tenant.as_ref()) {
        if *used_size <= max_size {
            continue;
        }
    }
— policies/tree.rs:745-753; when a leaf’s removal makes its parent a new leaf for that tenant, the parent is pushed back into the heap (:793-814), so eviction unwinds bottom-up exactly like the runtime’s real tree. remove_tenant (:823) does the same unwinding for worker removal. This mirrors your sketch’s “evict LRU leaves” but note the unit: characters per tenant, with no relationship to the worker’s actual KV-token budget.
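The leaf-peeling loop is easier to hold in your head in a single-tenant form. The sketch below assumes one tenant, a plain in-memory tree, and caller-supplied epochs; Node and evict_to_budget are hypothetical names, and the real code additionally tracks per-tenant ownership on every node.

```python
import heapq
import itertools


class Node:
    def __init__(self, text, epoch, parent=None):
        self.text = text      # edge label held by this node
        self.epoch = epoch    # last-access epoch (smaller = older)
        self.parent = parent
        self.children = []
        if parent is not None:
            parent.children.append(self)


def evict_to_budget(root, used_chars, max_chars):
    """Remove least-recently-used leaves until used_chars <= max_chars."""
    tie = itertools.count()  # tie-breaker so the heap never compares Node objects
    leaves, stack = [], [root]
    while stack:
        node = stack.pop()
        if node.children:
            stack.extend(node.children)
        elif node is not root:
            leaves.append((node.epoch, next(tie), node))
    heapq.heapify(leaves)
    evicted = []
    while leaves and used_chars > max_chars:
        _, _, leaf = heapq.heappop(leaves)
        parent = leaf.parent
        parent.children.remove(leaf)
        used_chars -= len(leaf.text)
        evicted.append(leaf.text)
        # Removing the last child turns the parent into a leaf: make it evictable too.
        if parent is not root and not parent.children:
            heapq.heappush(leaves, (parent.epoch, next(tie), parent))
    return used_chars, evicted


root = Node("", 0)
shared = Node("system prompt ", 5, root)
q1 = Node("question one", 1, shared)
q2 = Node("question two", 9, shared)
other = Node("unrelated", 3, root)
total = sum(len(n.text) for n in (shared, q1, q2, other))
remaining, evicted = evict_to_budget(root, total, max_chars=20)
print(remaining, evicted)
```

Note that the shared prefix survives until both of its children are gone: a parent only enters the heap once it has become a leaf, so eviction always unwinds from the tips inward.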
Worker health, load, and failure handling
core/worker.rs is the Envoy-cluster-management analog. Each worker has an AtomicBool health flag, an atomic in-flight load counter, a circuit breaker, and a background health checker hitting its health endpoint. Health transitions use hysteresis — failure_threshold consecutive failures to go down, success_threshold consecutive successes to come back:
if self.is_healthy()
    && failures >= self.metadata.health_config.failure_threshold as usize
{
    self.set_healthy(false);
    self.consecutive_failures.store(0, Ordering::Release);
}
— core/worker.rs:740-745 (the symmetric success path is :726-731). Policy-visible “load” is the router’s own in-flight counter, maintained RAII-style — and only for policies that need it:
let load_guard = ["cache_aware", "manual"]
    .contains(&policy.name())
    .then(|| WorkerLoadGuard::new(worker.clone(), headers));
— routers/http/router.rs:300-302. WorkerLoadGuard::new increments; Drop decrements (core/worker.rs:1109-1140), and the guard is attached to the streaming response body (AttachedBody, :1147) so load stays held until the SSE stream actually finishes — the same trick as deferred-done callbacks in Envoy. Policy eligibility = healthy AND circuit breaker closed (policies/mod.rs:136-143). Separately, a LoadMonitor (core/worker_manager.rs:269) polls each worker’s /v1/loads?include=core and extracts aggregate.total_tokens (:207-230) — but only power_of_two consumes that token-level signal (policies/power_of_two.rs:113-117, with graceful degradation to request counts at :75-85). Cache-aware never sees worker-reported load.
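Both worker-side ideas from this subsection, health hysteresis and a load counter held for the life of the stream, can be sketched in Python. ToyWorker, record_probe, load_guard, and stream are hypothetical names; the threshold defaults are placeholders, and the exact counter-reset behaviour is an assumption rather than a transcription of worker.rs.

```python
from contextlib import contextmanager


class ToyWorker:
    def __init__(self, failure_threshold=3, success_threshold=2):
        self.healthy = True
        self.load = 0
        self.failure_threshold = failure_threshold
        self.success_threshold = success_threshold
        self._failures = 0
        self._successes = 0

    def record_probe(self, ok):
        """Flip health only after a run of consecutive identical probe results."""
        if ok:
            self._failures = 0
            self._successes += 1
            if not self.healthy and self._successes >= self.success_threshold:
                self.healthy = True
                self._successes = 0
        else:
            self._successes = 0
            self._failures += 1
            if self.healthy and self._failures >= self.failure_threshold:
                self.healthy = False
                self._failures = 0

    @contextmanager
    def load_guard(self):
        """Count the request as in-flight until the with-block exits."""
        self.load += 1
        try:
            yield
        finally:
            self.load -= 1


def stream(worker, chunks):
    """Generator that keeps the guard held until the last chunk is consumed."""
    with worker.load_guard():
        for chunk in chunks:
            yield chunk


worker = ToyWorker()
body = stream(worker, ["a", "b"])
print(next(body), worker.load)  # a 1
print(list(body), worker.load)  # ['b'] 0
```

Tying the decrement to exhaustion of the generator is the Python analog of attaching the guard to the response body: the load is released when the stream ends, not when the handler returns.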
Where the approximation can go stale
This is the part to study hardest, because it’s the gap between your sketch and what you’d need to productionize one:
- Characters vs tokens vs pages. Router matches raw text chars; the worker's real tree matches token IDs, page-aligned (radix_cache.py:380). Chat-template expansion, tokenizer merges, and page truncation all make match_rate an estimate. Good enough to pick a worker; never exact.
- Optimistic insert, no confirmation. The tree is updated at selection time (cache_aware.rs:450), before the worker even accepts the request. If the request fails or the retry layer (routers/http/router.rs:217) re-routes it, the losing worker keeps phantom affinity until LRU eviction. on_request_complete is explicitly a no-op (cache_aware.rs:507-517: "Could track success rates per worker").
- Independent eviction clocks. Router evicts by char budget on a timer; the worker evicts its real radix tree by token budget whenever decode needs memory (radix_cache.py:537). A worker under memory pressure can dump exactly the prefixes the router still credits it with, and the router will happily keep routing "hits" that are now misses. Nothing reports worker evictions back (contrast Dynamo, below).
- Worker restarts / cache flushes. /flush_cache exists on both layers, but a worker that restarts behind a stable URL comes back cold while the router's tree still says hot.
- Multi-tenant ambiguity. When several workers legitimately share a prefix, prefix_match_with_counts returns the cached/arbitrary last_tenant (tree.rs:568-599), not the least-loaded owner.
- Multi-router deployments. Each router instance learns only from its own traffic. There's a gossip "mesh" to sync tree ops (cache_aware.rs:452-463), but the receive path "is not yet wired" (:269-272), so HA router pairs today have divergent trees.
None of these break correctness — worst case is a wasted prefill, the same graceful degradation your sketch had. That’s the design’s actual thesis: affinity is a hint, load is the guardrail, health is the gate.
PD mode and the gRPC pipeline (brief)
In prefill/decode disaggregation the router picks a pair: select_pd_pair (routers/http/pd_router.rs:793-872) runs independent policies over the prefill pool and decode pool (--prefill-policy cache_aware --decode-policy power_of_two is the canonical combo), then injects bootstrap_host/bootstrap_port/bootstrap_room into the JSON body (:228-276) so the decode worker can pull KV from the prefill worker. The gRPC router (routers/grpc/) goes further: it tokenizes in Rust, runs reasoning/tool-call parsing in Rust, and streams token IDs to SRT workers over the proto/ contract — the “industry-first gRPC pipeline” pitch in sgl-model-gateway/README.md. That’s the path where routing can use real token IDs instead of chars.
4.2 RadixAttention: the real prefix cache inside each worker
python/sglang/srt/mem_cache/radix_cache.py is the ground truth the router approximates. Same data structure, different payload — here tree nodes map token-ID keys to KV-cache page indices:
self.children = defaultdict(TreeNode)
self.parent: TreeNode = None
self.key: RadixKey = None
self.value: Optional[torch.Tensor] = None
self.lock_ref = 0
self.last_access_time = time.monotonic()
— radix_cache.py:206-211 (TreeNode.__init__, :201). key is a RadixKey (token ids + optional extra_key namespace — LoRA ID, sampling salt — so requests that must not share KV get disjoint subtrees, :337-349); value is a tensor of KV pool slot indices; lock_ref pins nodes that in-flight requests depend on.
Match (match_prefix, :337; helper :622-646) walks children keyed by the first page of tokens, splitting a node when the match ends mid-segment:
while len(key) > 0 and child_key in node.children.keys():
    child = node.children[child_key]
    child.last_access_time = access_time
    prefix_len = child.key.match(key, page_size=self.page_size)
    if prefix_len < len(child.key):
        new_node = self._split_node(child.key, child, prefix_len)
        value.append(new_node.value)
        node = new_node
        break
    else:
        value.append(child.value)
        node = child
        key = key[prefix_len:]
— radix_cache.py:629-641. The returned indices become req.prefix_indices — those KV slots are simply reused, which is the whole RadixAttention trick: prefill only computes the un-matched suffix.
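The effect of page alignment on a match is easy to see in isolation. This is a minimal sketch, assuming a flat list of cached token IDs rather than a tree; page_aligned_match_len is an illustrative helper, not a function from radix_cache.py.

```python
def page_aligned_match_len(cached, query, page_size=1):
    """Length of the shared token prefix, rounded down to a whole number of pages."""
    matched = 0
    for a, b in zip(cached, query):
        if a != b:
            break
        matched += 1
    return (matched // page_size) * page_size


cached = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
query = [1, 2, 3, 4, 5, 6, 7, 99]
print(page_aligned_match_len(cached, query, page_size=1))  # 7
print(page_aligned_match_len(cached, query, page_size=4))  # 4
```

With a page size of 4, three matching tokens are discarded because they do not fill a page, so prefill recomputes 4 tokens of the query instead of 1. The router's character-level estimate knows nothing about this rounding.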
Insert happens when requests finish (cache_finished_req, :417) or at each chunked-prefill boundary (cache_unfinished_req, :464 — so a 100k-token prompt becomes shareable chunk by chunk, and the request re-matches its own insertion then re-locks the deeper node, :495-516). _insert_helper (:678) returns how many tokens were already present so the duplicate KV slots get freed back to the allocator (:448-451).
Evict (:537) is called by the scheduler whenever the token allocator can’t satisfy a prefill/decode allocation. It heapifies the evictable leaves under a pluggable strategy and frees leaf-by-leaf, cascading to parents:
leaves = list(self.evictable_leaves)
eviction_heap = [
    (self.eviction_strategy.get_priority(node), node) for node in leaves
]
heapq.heapify(eviction_heap)
num_evicted = 0
while num_evicted < num_tokens and len(eviction_heap):
    _priority, x = heapq.heappop(eviction_heap)
    self.token_to_kv_pool_allocator.free(x.value)
    num_evicted += len(x.value)
    self._delete_leaf(x)
— radix_cache.py:543-555. Strategies live in mem_cache/evict_policy.py:10-66: LRU (default), LFU, FIFO, MRU, priority-aware, SLRU. Pinning: inc_lock_ref/dec_lock_ref (:566-600) move subtree token counts between evictable_size_ and protected_size_ — the scheduler’s admission math budgets against evictable size only. The RadixAttention layer (python/sglang/srt/layers/radix_attention.py:57) is thin: forward just dispatches q/k/v plus the paged metadata to the active attention backend (:145-153 → get_attn_backend().forward(...)), e.g. layers/attention/flashinfer_backend.py, which is what actually computes attention over the non-contiguous cached pages. Variants you can ignore initially: hiradix_cache.py (hierarchical GPU→CPU→storage tiers), swa_radix_cache.py (sliding-window), mamba_radix_cache.py, radix_cache_cpp.py.
4.3 Scheduler: continuous batching, SGLang-style
python/sglang/srt/managers/scheduler.py is a single-threaded event loop per GPU process. The classic loop is event_loop_normal (:1432): recv requests → build batch → run → process results. The default is event_loop_overlap (:1459), which hides all CPU work (detokenization dispatch, radix-cache updates, next-batch construction) behind the GPU:
# Get the next batch to run
batch = self.get_next_batch_to_run()
self.cur_batch = batch
disable_overlap_for_batch = self.is_disable_overlap_for_batch(batch)
# If we do not need to overlap the current batch with the last batch,
# we can process the last batch immediately.
if disable_overlap_for_batch:
    pop_and_process()
# Launch the current batch
if batch:
    batch_result = self.run_batch(batch)
    self.result_queue.append((batch.copy(), batch_result))
— scheduler.py:1481-1494. Batch N+1 is scheduled and launched while batch N’s results are still on the GPU; results are popped from result_queue one iteration later (:1499-1501), and sampling of the current batch is deferred until the previous batch’s results are in because grammar masks depend on them (:1506-1509). This is the “zero-overhead batch scheduler” from the v0.4 blog; vLLM v1 later adopted a similar async pipeline, but it’s been SGLang’s default posture for ~two years, and it’s why SGLang’s small-model decode throughput stays GPU-bound.
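The one-iteration lag between launching a batch and processing its results is the part worth internalizing. The sketch below only simulates the ordering with a queue; there is no real concurrency or GPU work, and overlap_loop is an illustrative name rather than anything in scheduler.py.

```python
from collections import deque


def overlap_loop(batches):
    """Launch batch i, then post-process batch i-1 while i is (notionally) on the GPU."""
    result_queue = deque()
    trace = []
    for batch in batches:
        trace.append(f"launch {batch}")
        result_queue.append(batch)  # stand-in for (batch.copy(), batch_result)
        if len(result_queue) > 1:
            trace.append(f"process {result_queue.popleft()}")
    while result_queue:  # drain the final in-flight batch
        trace.append(f"process {result_queue.popleft()}")
    return trace


for step in overlap_loop(["b0", "b1", "b2"]):
    print(step)
```

In the trace, every "process" line for a batch appears after the "launch" line of its successor, which is where the CPU work gets hidden in the real loop.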
Batch construction, get_next_batch_to_run (:2426): merge last prefill batch into the running decode batch (:2467-2494), then prefill-first: if get_new_batch_prefill produces a batch it runs, otherwise decode (:2524-2535). Inside prefill admission (_get_new_batch_prefill_raw, :2591):
- The waiting queue is priority-sorted by cache awareness: self.policy.calc_priority(...) (:2630) with default policy lpm (longest prefix match). In managers/schedule_policy.py, every waiting request gets match_prefix run against the radix tree, then:
waiting_queue.sort(
    key=lambda r: (
        -r.num_matched_prefix_tokens
        if r.rid not in temporary_deprioritized
        else float("inf")
    )
)
— schedule_policy.py:300-306. Requests that reuse the most cache go first (maximizing hit rate and minimizing prefill compute); over 128 queued it degrades to FCFS to cap sort cost (:223-227). The temporary_deprioritized set is in-batch prefix caching (:247-293): if several queued requests share a prefix that isn’t cached yet, schedule one and hold the rest so they hit the cache the first one populates — fleet-level dedup at the scheduler.
- A PrefillAdder (schedule_policy.py:425, constructed at scheduler.py:2647-2663) admits requests under token budgets (max_prefill_tokens, remaining KV, new_token_ratio, an adaptive estimate of decode growth), chunking oversized prompts (chunked prefill, add_chunked_req) and optionally mixing chunked prefill with decode in one forward (is_mixed_chunk).
- Under decode OOM, update_running_batch (:2864) retracts: batch.retract_decode(...) (:2892) pops the youngest requests out of the running batch back into the waiting queue (their KV freed-or-cached; they later resume by re-matching the radix tree) and tightens new_token_ratio, so admission control learns from the event. The log line to grep is "KV cache pool is full. Retract requests." (:2925-2937).
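The queue ordering from the first bullet, including the FCFS fallback for long queues and the in-batch deprioritization, can be sketched on plain request IDs. The function and argument names are illustrative; the 128 cutoff is the figure quoted above.

```python
import math


def order_waiting_queue(queue, matched_tokens, deprioritized=frozenset(), fcfs_above=128):
    """Sort request ids by longest cached prefix; keep arrival order for long queues."""
    if len(queue) > fcfs_above:
        return list(queue)  # FCFS: cap the sort cost by skipping the sort
    return sorted(
        queue,
        key=lambda rid: math.inf if rid in deprioritized else -matched_tokens[rid],
    )


queue = ["r1", "r2", "r3", "r4"]
matched = {"r1": 0, "r2": 512, "r3": 2048, "r4": 2048}
# r4 shares an uncached prefix with a request already picked, so it is held back.
print(order_waiting_queue(queue, matched, deprioritized={"r4"}))
# ['r3', 'r2', 'r1', 'r4']
```

Because Python's sort is stable, requests with equal match lengths keep their arrival order, which is a reasonable tie-break for a sketch.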
For multi-GPU: TP ranks each run a scheduler process in lockstep; --dp-size adds a data_parallel_controller.py dispatching to replicas; PD disaggregation (srt/disaggregation/) splits prefill and decode into separate servers with KV transfer (mooncake/nixl backends) — which is exactly what the router’s PD mode fronts.
4.4 Structured output (brief)
Grammar backends live in python/sglang/srt/constrained/; default is xgrammar (server_args.py:3357-3359), alternatives llguidance and outlines. Compilation is async so it never stalls the GPU loop: requests needing a grammar park in GrammarManager.grammar_queue (constrained/grammar_manager.py:29) with a compilation-cache lookup (:111), and the scheduler pulls ready ones back into the waiting queue at the top of every prefill pass (scheduler.py:2594-2598). At decode time each request’s matcher fills a per-row bitmask which is applied to logits on-GPU:
def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
    self.matcher.fill_next_token_bitmask(vocab_mask, idx)
@staticmethod
def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
    return vocab_mask.to(device, non_blocking=True)
def apply_vocab_mask(self, logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
    if logits.device.type in {"cuda", "xpu", "musa"}:
        if _is_hip:
            apply_token_bitmask_inplace_cuda(logits, vocab_mask)
        else:
            apply_token_bitmask_inplace_triton(logits, vocab_mask)
— constrained/xgrammar_backend.py:105-117. You already know grammar bitmasks; the SGLang-specific bits are the async compile queue, the reasoner_grammar_backend.py wrapper (mask only after </think>), and the interaction with overlap scheduling (scheduler.py:1537-1546 — spec+grammar forces overlap off for that batch).
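What "apply a bitmask to logits" means can be shown without a GPU. The sketch below assumes a packed layout where token i is allowed when bit i % 32 of word i // 32 is set; that layout is an assumption for the example, and the pure-Python function is a stand-in for the in-place kernels named above, not their implementation.

```python
import math


def apply_token_bitmask(logits, bitmask_words):
    """Return logits with disallowed tokens set to -inf (bit set = token allowed)."""
    return [
        logit if (bitmask_words[i // 32] >> (i % 32)) & 1 else -math.inf
        for i, logit in enumerate(logits)
    ]


logits = [0.1, 2.0, -1.0, 0.5]
mask = [0b0101]  # tokens 0 and 2 allowed, tokens 1 and 3 forbidden
print(apply_token_bitmask(logits, mask))  # [0.1, -inf, -1.0, -inf]
```

After masking, the highest-scoring token (index 1) can no longer be sampled, which is how the grammar constrains output without touching the model.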
4.5 Streaming / detokenization egress
Egress is its own process so UTF-8 assembly and stop-string trimming never block scheduling. DetokenizerManager holds per-request DecodeStatus offsets (managers/detokenizer_manager.py:61-86, LRU-capped at 65536 states, :58) and does incremental decode with explicit handling of byte-incomplete codepoints:
new_text = read_texts[i][len(surr_texts[i]) :]
if recv_obj.finished_reasons[i] is None:
    # Streaming. Invariant: sent_offset >= decoded_text_len. The
    # gap (`pending`) is "printable but uncommitted" text emitted
    # in a prior "�" recovery step; we skip it from this step's
    # emission so we don't double-send.
    pending = s.sent_offset - s.decoded_text_len
    if new_text and not new_text.endswith("�"):
        # Clean text: commit to decoded_text and advance offsets.
        s.append_decoded_text(new_text)
        s.surr_offset = s.read_offset
        s.read_offset = len(s.decode_ids)
        s.sent_offset = s.decoded_text_len
        output_strs.append(new_text[pending:] if pending else new_text)
— detokenizer_manager.py:347-360. The “surrogate window” re-decodes a small token suffix each step so tokenizer merge effects stay correct without re-decoding the whole sequence. Downstream, TokenizerManager matches outputs to waiting request states and the FastAPI layer streams SSE; per-request cached_tokens (your radix-hit observability) rides along in BatchStrOutput (managers/io_struct.py:1126). Note the symmetry with the gateway: in gRPC mode this entire egress stage is reimplemented in Rust (sgl-model-gateway/src/routers/grpc/).
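The core hold-back rule, do not emit text that ends in the replacement character, can be sketched on raw bytes. This is a deliberately simplified model: it treats each token as a byte string and has no surrogate window or offset bookkeeping, and stream_decode is a hypothetical name.

```python
def stream_decode(token_bytes):
    """Yield text increments, holding back output while the tail is an incomplete UTF-8 sequence."""
    buffer = b""
    for tok in token_bytes:
        buffer += tok
        text = buffer.decode("utf-8", errors="replace")
        if text.endswith("\ufffd"):
            continue  # incomplete codepoint: wait for more bytes
        yield text
        buffer = b""


# "é" is two bytes in UTF-8 (0xC3 0xA9), split here across two tokens.
tokens = [b"caf", b"\xc3", b"\xa9", b"!"]
print(list(stream_decode(tokens)))  # ['caf', 'é', '!']
```

A real implementation also needs a policy for bytes that never complete (this sketch would hold them forever and drops any held tail at end of stream), which is the kind of case the recovery step in the excerpt above is concerned with.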
5. Suggested reading path
Router first — it’s self-contained, idiomatic Rust, and small enough to read fully:
- sgl-model-gateway/src/policies/cache_aware.rs:1-60: the design doc is the file header. Then select_worker (:376-505).
- sgl-model-gateway/src/policies/tree.rs: Node/Tree (:229-250), insert (:362), prefix_match_with_counts (:531), evict_tenant_by_size (:718). Compare against the prefix tree you sketched.
- sgl-model-gateway/src/policies/mod.rs: the policy trait + config; then skim power_of_two.rs and prefix_hash.rs as contrasting policies.
- sgl-model-gateway/src/core/worker.rs: health hysteresis (:703-752), WorkerLoadGuard (:1109-1140); skim circuit_breaker.rs, retry.rs.
- sgl-model-gateway/src/routers/http/router.rs: request flow end-to-end (route_typed_request, :194); then pd_router.rs:793 for pair selection.
- python/sglang/srt/mem_cache/radix_cache.py: the real tree (whole file, 802 lines), with evict_policy.py alongside.
- python/sglang/srt/managers/schedule_policy.py: LPM sorting + in-batch prefix caching (:133-306), then PrefillAdder (:425+).
- python/sglang/srt/managers/scheduler.py: event_loop_overlap (:1459), get_next_batch_to_run (:2426), retraction (:2864-2937). Skim schedule_batch.py for Req lifecycle.
- python/sglang/srt/managers/detokenizer_manager.py: egress (:260-383).
- Optional depth: constrained/grammar_manager.py, layers/attention/flashinfer_backend.py, disaggregation/.
6. Connections to your other study repos
- vllm: Same problem, different prefix-cache shape: vLLM v1 prefix caching hashes full blocks (block-granular hits) where SGLang's radix tree gives token-granular hits and drives scheduling order (LPM) off the tree. vLLM preempts by recompute/swap; SGLang retracts with adaptive new_token_ratio (scheduler.py:2892). Overlap scheduling is default in SGLang, async-scheduling arrived later in vLLM. Reading both schedulers back-to-back is the single highest-leverage comparison.
- nano-vllm: Its ~1200 lines map 1:1 onto scheduler.py's skeleton (event_loop_normal → step; block manager → token_to_kv_pool_allocator). Use it as the decoder ring: every nano concept exists here, buried under production features.
- dynamo: The other answer to your routing question. Dynamo's KV-aware router consumes KV events published by workers (block stored/evicted) into a global indexer, so its view of cache state is near ground truth, plus it factors in active prefill load; cost = worker instrumentation, event bus, and an indexer in the hot path. sgl-router's tree is inferred from routing history only: zero worker cooperation, works with any backend speaking HTTP, but suffers the staleness modes in §4.1. That's the central trade: telemetry-driven vs inference-driven affinity. (Dynamo can also front SGLang workers; SGLang emits KV events for that integration: mem_cache/events.py and the _record_store_event/_record_remove_event calls in radix_cache.py:561,729.)
- llm-d: Kubernetes-native, vLLM-first; its inference scheduler scores pods with pluggable scorers (prefix-cache scorer, load scorer) inside the EPP. Its prefix scorer is conceptually sgl-router's tree rebuilt as a gateway plugin. Compare its scorer-weighting approach with sgl-router's hard threshold switch.
- gateway-api-inference-extension: Your Envoy home turf: Envoy stays the data plane, and an ext-proc "endpoint picker" (EPP) sets the upstream via header mutation; InferencePool/InferenceModel CRDs make it portable across GKE/Istio/etc. sgl-router instead is the data plane: a standalone Rust proxy owning retries, CBs, rate limiting, SSE, even tokenization, with no Envoy anywhere. Trade-offs you'll recognize immediately: GIE composes with existing mesh/policy infra and keeps LB logic out-of-process (ext-proc RTT per request); sgl-router is one static binary with policy in-process (sub-ms decisions) but reinvents the proxy layer (look at src/core/retry.rs, src/middleware.rs and judge the reinvention quality yourself).
- xgrammar: It's the default --grammar-backend here; constrained/xgrammar_backend.py is a thin adapter over its GrammarMatcher + token bitmask kernels. SGLang co-evolved with xgrammar (same LMSYS orbit), and the jump-forward decoding idea originated in the SGLang paper.
- flashinfer: SGLang's premier NVIDIA attention backend (layers/attention/flashinfer_backend.py): RadixAttention's paged, ragged batches are exactly the workload flashinfer's plan/run API was built for. The attention_registry.py shows the full backend zoo and how prefill/decode can use different backends.
7. Tinkering on one RTX 5080 (16 GB)
Caveats first: the runtime is Linux-targeted — on your Windows box run it inside WSL2 (CUDA passthrough) or the lmsysorg/sglang Docker image. The 5080 is Blackwell sm_120: use a recent SGLang + matching CUDA 12.8+ torch wheels (the repo carries sm120-specific kernels, e.g. python/sglang/srt/layers/attention/flash_mla_sm120.py; if an attention backend complains, --attention-backend triton is the safe fallback).
Single worker sanity check (a ~1B model leaves plenty of room in 16 GB):
python -m sglang.launch_server \
    --model-path Qwen/Qwen3-0.6B \
    --port 30001 --mem-fraction-static 0.35 \
    --enable-metrics
curl http://127.0.0.1:30001/generate \
    -d '{"text": "The capital of France is", "sampling_params": {"max_new_tokens": 32}}'
mem_fraction_static (server_args.py:434) caps weights+KV pool as a fraction of total VRAM, so two workers fit on one GPU. Watch the startup log for KV Cache is allocated and the per-prefill log line — this is your ground-truth cache observability (managers/scheduler_components/metrics_reporter.py:526-535):
Prefill batch, #new-seq: 1, #new-token: 6, #cached-token: 0, ...
Two workers + router on one GPU. Don’t use sglang_router.launch_server co-launch for this — it assumes one GPU per DP rank (bindings/python/src/sglang_router/launch_server.py:92 sets base_gpu_id = dp_id * tp_size). Launch manually:
# terminal 1 and 2 — two workers, same GPU, ~35% VRAM each
python -m sglang.launch_server --model-path Qwen/Qwen3-0.6B \
    --port 30001 --mem-fraction-static 0.35 --enable-metrics
python -m sglang.launch_server --model-path Qwen/Qwen3-0.6B \
    --port 30002 --mem-fraction-static 0.35 --enable-metrics
# terminal 3 — the router (pip install sglang-router, or cargo build in sgl-model-gateway/)
python -m sglang_router.launch_router \
    --worker-urls http://127.0.0.1:30001 http://127.0.0.1:30002 \
    --policy cache_aware \
    --cache-threshold 0.3 --balance-abs-threshold 64 --balance-rel-threshold 1.5 \
    --eviction-interval 120
Router listens on :30000, Prometheus metrics on :29000 (src/main.rs:270).
An experiment that makes cache-aware routing visible. Build M distinct “tenants” (long shared system prompts) and fire interleaved requests:
python -m sglang.bench_serving --backend sglang \
    --host 127.0.0.1 --port 30000 \
    --dataset-name generated-shared-prefix \
    --gsp-num-groups 8 --gsp-prompts-per-group 32 \
    --gsp-system-prompt-len 2048 --gsp-question-len 128 --gsp-output-len 64
(bench_serving.py:2027,2369-2398.) With cache_aware, each prompt group converges onto one worker, so repeats of a group’s 2048-token preamble hit that worker’s radix cache. What to watch:
- Per-worker cache hit rate: sglang:cache_hit_rate gauge on each worker's /metrics (observability/metrics_collector.py:291-296), or eyeball #cached-token vs #new-token in the worker logs. Cache-aware should hold this high (~90%+ on repeat traffic); now rerun the bench with --policy round_robin on the router and watch it crater toward ~50% with 2 workers (each group's prefix gets prefilled everywhere).
- Router selection counters: on :29000/metrics, smg_worker_selection_total{policy="cache_aware",...} per worker (src/observability/metrics.rs:237) shows the affinity skew; smg_worker_requests_active (:225) shows in-flight balance; smg_worker_health (:229) for the health gate.
- End-to-end TTFT: bench_serving's mean TTFT under cache_aware vs round_robin is the headline number (cached prefill skips ~2048 of ~2176 prompt tokens per request).
- The threshold switch in action: rerun with --balance-abs-threshold 2 --balance-rel-threshold 1.1 and high concurrency: the imbalance detector now trips constantly, the router degrades to shortest-queue, selections even out, and the hit rate drops. That knob-feel, affinity vs balance as a tunable dial, is the entire thesis of the design, and you can demo it with two 0.6B workers on your one card.
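As a sanity check on the cache-aware figures in this list, here is a back-of-the-envelope sketch for the bench command above. It assumes the only reuse is the system prompt, that each group's first request is the single cold one, and that nothing is ever evicted; the function name is illustrative, and real measurements will differ with concurrency and cache pressure.

```python
def ideal_hit_rate(groups, prompts_per_group, system_len, question_len):
    """Fraction of prompt tokens served from cache when only system prompts are reused."""
    total = groups * prompts_per_group * (system_len + question_len)
    cached = groups * (prompts_per_group - 1) * system_len  # first request per group is cold
    return cached / total


# Per warm request: 2048 of 2176 prompt tokens come from cache.
print(round(2048 / (2048 + 128), 3))                # 0.941
# Over the whole run, counting one cold request per group.
print(round(ideal_hit_rate(8, 32, 2048, 128), 3))   # 0.912
```

Both numbers line up with the "~90%+" expectation for cache_aware and with the "~2048 of ~2176" TTFT remark.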
To inspect the router’s mental model directly, GET http://127.0.0.1:30000/v1/loads returns each worker’s self-reported token load (src/server.rs:607-612), and the router’s debug logs (RUST_LOG=debug) print load-balancing trigger events with per-worker loads (policies/cache_aware.rs:318-326). For staleness experiment #3 from §4.1: hit one worker’s /flush_cache, keep the traffic running, and watch its #cached-token go to zero while the router keeps sending the same groups there — the approximation never notices.