XGrammar
A study guide to the XGrammar repo checked out at materials/xgrammar/xgrammar (MLC’s structured-generation engine, v0.2.2 — i.e. “XGrammar-2”). Written for someone who already understands the constrained-decoding loop and wants to know where each piece lives in the code.
1. What it is and where it sits
XGrammar is the CPU-side brain of structured output for most serving stacks. It takes a structure spec (JSON Schema, regex, EBNF, or “structural tags” for tool calling), compiles it against a specific tokenizer vocabulary, and then — per request, per decode step — answers one question: which token IDs are legal next? The answer comes back as a packed bitmask that the engine applies to logits on the GPU before sampling.
It is the default structured-output backend in vLLM, SGLang, TensorRT-LLM, and MLC-LLM. The library itself is a C++17 core (cpp/) with thin Python bindings (python/xgrammar/, bound via apache-tvm-ffi), plus a handful of GPU kernels for the one step that touches the GPU. Everything else — compilation, state tracking, mask filling — runs on the host CPU.
One historical note that matters for reading the code: XGrammar v1 (the MLSys ’25 paper) used a persistent-stack pushdown automaton. XGrammar-2 (this code, the 2026 paper) replaced the PDA with an Earley parser over per-rule FSMs. Several docstrings still say “pushdown automaton” (include/xgrammar/matcher.h:43, python/xgrammar/matcher.py:205) — the actual engine is cpp/earley_parser.cc. Your PDA mental model still works: the Earley state sets are the nondeterministic stack configurations, with rule_start_pos playing the role of the stack back-pointer.
2. Why you care
You already derived the loop. Here is the same loop, named with this codebase’s vocabulary:
| Your concept | XGrammar name | Where |
|---|---|---|
| CPU-side automaton holding parse state | GrammarMatcher wrapping an EarleyParser | cpp/grammar_matcher.cc, cpp/earley_parser.cc |
| Per-step token bitmask emitted by CPU | FillNextTokenBitmask | cpp/grammar_matcher.cc:809 |
| GPU applies mask to logits before sampling | apply_token_bitmask_inplace (Triton/CUDA/CPU) | python/xgrammar/kernels/ |
| Sampled token returns to advance the automaton | AcceptToken | cpp/grammar_matcher.cc:610 |
| Context-independent vs context-dependent token split | AdaptiveTokenMask: accepted/rejected (precomputed) vs uncertain (runtime) | cpp/compiled_grammar_impl.h:40 |
| Precompute-vs-on-the-fly tradeoff | GrammarCompiler + per-ParserState mask cache | cpp/grammar_compiler.cc:1048 |
| Tokens straddling grammar boundaries | the “uncertain” class + lookahead assertions | cpp/grammar_compiler.cc:641 |
| Jump-ahead / forced tokens | FindJumpForwardString | cpp/grammar_matcher.cc:964 |
| Rollback for speculative decoding | Rollback, TraverseDraftTree | cpp/grammar_matcher.cc:1023, :55 |
The interesting engineering lives in two places: how much of the mask can be precomputed per automaton state (compile time), and how cheaply the rest can be resolved per step (runtime). The rest of this guide maps those two.
3. Architecture map
```text
xgrammar/
├── include/xgrammar/ # public C++ API (5 headers, including grammar, compiler, matcher, tokenizer_info)
├── cpp/ # the core
│   ├── grammar_parser.cc # EBNF text -> grammar IR
│   ├── json_schema_converter.cc # JSON Schema -> EBNF string (3200 lines; the biggest file)
│   ├── regex_converter.cc # regex -> EBNF
│   ├── structural_tag.cc # tool-call "structural tags" -> grammar (XGrammar-2 agentic feature)
│   ├── grammar_functor.cc # grammar passes: normalize, inline, dead-rule elim, build per-rule FSMs (GrammarOptimizer at grammar_functor.h:409)
│   ├── fsm.cc / fsm_builder.cc # FSM data structure; regex->FSM and trie->FSM builders
│   ├── grammar_compiler.cc # adaptive token-mask precompute + grammar/rule caches
│   ├── compiled_grammar_impl.h # CompiledGrammar = grammar + tokenizer + mask cache; AdaptiveTokenMask struct
│   ├── earley_parser.{h,cc} # the runtime automaton (scan/predict/complete)
│   ├── grammar_matcher.cc # GrammarMatcher: AcceptToken, FillNextTokenBitmask, jump-forward, rollback
│   ├── tokenizer_info.cc # vocab decoding (byte-fallback/byte-level), sorted vocab, prefix-trie ranges
│   └── support/ # dynamic_bitset.h (the mask), thread_safe_cache.h (LRU), thread_pool.h
├── python/xgrammar/ # bindings mirror the C++ classes 1:1
│   ├── grammar.py, compiler.py, matcher.py, tokenizer_info.py
│   ├── kernels/ # apply_token_bitmask_inplace_{triton,cuda,cpu,torch,torch_compile,mlx}
│   └── testing.py # debug helpers (_get_masked_tokens_from_bitmask, etc.)
├── docs/tutorials/ # engine_integration.md is the integration contract
├── examples/benchmark/ # bench_grammar_compile_mask_gen.py — compile + mask-fill latency
└── tests/python/ # extensive; test_grammar_matcher_*.py show the API in anger
```
4. Core mechanisms
4.1 Grammar compilation: JSON Schema → EBNF → IR → FSMs
Every input format funnels into EBNF text, then into a grammar IR, then per-rule FSMs:
```cpp
// cpp/grammar.cc:43
Grammar Grammar::FromJSONSchema(
    const std::string& schema, bool any_whitespace, std::optional<int> indent,
    std::optional<std::pair<std::string, std::string>> separators, bool strict_mode,
    std::optional<int> max_whitespace_cnt, bool print_converted_ebnf
) {
  auto ebnf_string =
      JSONSchemaToEBNF(schema, any_whitespace, indent, separators, strict_mode, max_whitespace_cnt);
  ...
  return FromEBNF(ebnf_string);
}
```
JSONSchemaToEBNF (cpp/json_schema_converter.cc:3078) walks the schema and emits grammar rules. The IR node types are in cpp/grammar_impl.h:108 — kByteString, kCharacterClass(Star), kRuleRef, kSequence, kChoices, kTagDispatch, kRepeat, plus the XGrammar-2 token-level edges kToken/kExcludeToken (grammar positions that match whole token IDs, not bytes — used for tool-call dispatch). GrammarOptimizer::Apply (cpp/grammar_functor.h:409) then normalizes, inlines, and builds one FSM per rule (per_rule_fsms).
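To make the first hop concrete, here is a deliberately tiny schema-to-EBNF converter in Python. It is a sketch, not a port of JSONSchemaToEBNF: it only handles a top-level object whose properties are all required and typed string or integer, emits them in declaration order with fixed ", " and ": " separators (no optional whitespace), and the function name schema_to_ebnf and the rule names are made up for illustration. The output uses GBNF-style EBNF (::= rules, quoted literals, character classes); treat the exact escaping as an assumption.

```python
import json


def schema_to_ebnf(schema_text: str) -> str:
    """Hypothetical toy converter: object of required string/integer properties -> EBNF."""
    schema = json.loads(schema_text)
    if schema.get("type") != "object":
        raise ValueError("sketch only handles a top-level object")
    props = schema.get("properties", {})
    if set(props) != set(schema.get("required", [])):
        raise ValueError("sketch assumes every property is required")
    pieces = []
    for name, sub in props.items():
        kind = sub.get("type")
        if kind not in ("string", "integer"):
            raise ValueError(f"unsupported property type: {kind!r}")
        # EBNF literal for the key text, e.g. "\"name\": " (JSON-escaped twice)
        key_literal = json.dumps(json.dumps(name) + ": ")
        pieces.append(f"{key_literal} {kind}")
    body = ' ", " '.join(pieces)  # fixed separator between members
    rules = [
        f'root ::= "{{" {body} "}}"',
        r'string ::= "\"" [^"\\]* "\""',           # no escapes inside strings (simplified)
        r'integer ::= "-"? ("0" | [1-9] [0-9]*)',
    ]
    return "\n".join(rules)
```

The real converter also deals with optional properties, the whitespace options in its signature (any_whitespace, indent, separators), nesting, and the rest of JSON Schema; the shape of the output (one rule per sub-structure, referenced by name) is the part this sketch shares with it.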
Where the compilation cost lives. MultiThreadCompileGrammar (cpp/grammar_compiler.cc:1048) enumerates every reachable, scanable FSM state of every rule and precomputes a token mask for each, fanned out over a thread pool:
```cpp
// cpp/grammar_compiler.cc:1109
for (int32_t rule_id = 0; rule_id < ...NumRules(); ++rule_id) {
  auto rule = compiled_grammar_impl->grammar->GetRule(rule_id);
  const auto& rule_fsm = compiled_grammar_impl->grammar->per_rule_fsms[rule_id];
  auto cur_stack_element =
      ParserState(rule_id, rule.body_expr_id, 0, ParserState::kNoPrevInputPos, 0);
  std::unordered_set<int> reachable_states;
  rule_fsm->GetFsm().GetReachableStates(&reachable_states);
  for (int i : reachable_states) {
    cur_stack_element.element_id = i;
    if (!rule_fsm->GetFsm().IsScanableState(i)) continue;
    add_task_adaptive_token_mask(cur_stack_element, rule_id == root_rule_id);
  }
}
```
Each task runs the parser over the ~100k-token vocabulary, pruned by a first-character check and prefix-trie skipping (4.2 and 4.4); this is the O(states × vocab) precompute you predicted. It is amortized by two caches:
- Grammar-level LRU cache: ThreadSafeLRUCache<UnionKey, CompiledGrammar, ...> grammar_level_cache_ (cpp/grammar_compiler.cc:1400, implemented in cpp/support/thread_safe_cache.h), keyed by the schema/EBNF/regex string plus options (cpp/grammar_compiler.cc:1452). Same schema twice → free.
- Rule-level “crossing cache” (XGrammar-2): per-rule FSMs are structurally hashed (GrammarFSMHasher, cpp/grammar_functor.h:426), and masks are reused across different grammars that share sub-rules (cpp/grammar_compiler.cc:794-833). Two different tool schemas that both contain a JSON string rule share that rule’s mask work.
CompiledGrammar is serializable (include/xgrammar/compiler.h:39) so engines can persist compilation across processes.
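Both cache layers can be sketched in a few lines of Python. This is an illustration under assumptions, not the C++ implementation: LRUCache is a hypothetical OrderedDict-plus-lock stand-in for ThreadSafeLRUCache, meant to be keyed by (source text, options), and structural_key is a hypothetical stand-in for GrammarFSMHasher that canonicalizes a deterministic FSM by renumbering its states in BFS order, so two rules that differ only in state names get the same key.

```python
import threading
from collections import OrderedDict, deque


class LRUCache:
    """Thread-safe LRU cache (a stand-in for ThreadSafeLRUCache, not its implementation)."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.lock = threading.Lock()
        self.data = OrderedDict()

    def get_or_compute(self, key, compute):
        with self.lock:
            if key in self.data:
                self.data.move_to_end(key)  # mark as most recently used
                return self.data[key]
        # Compile outside the lock so unrelated keys are not serialized behind it.
        # Two threads may compile the same key concurrently; the later write wins.
        value = compute()
        with self.lock:
            self.data[key] = value
            self.data.move_to_end(key)
            while len(self.data) > self.capacity:
                self.data.popitem(last=False)  # evict the least recently used entry
        return value


def structural_key(start, edges, finals):
    """Name-independent key for a deterministic FSM.

    edges: dict state -> list of (label, next_state); finals: set of accepting states.
    States are renumbered in BFS order from `start`, visiting edges sorted by label.
    """
    order, queue, canon = {start: 0}, deque([start]), []
    while queue:
        s = queue.popleft()
        for label, t in sorted(edges.get(s, []), key=lambda e: e[0]):
            if t not in order:
                order[t] = len(order)
                queue.append(t)
            canon.append((order[s], label, order[t]))
    return tuple(canon), frozenset(order[f] for f in finals if f in order)
```

A rule-level cache is then just a dict from structural_key(...) to that rule's precomputed masks, shared across every grammar compiled by the same compiler.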
4.2 The adaptive token mask cache — your context-independent/dependent split
This is the data structure to remember. For each automaton state, the vocab is partitioned exactly the way you derived it:
```cpp
// cpp/compiled_grammar_impl.h:40
struct AdaptiveTokenMask {
  enum class StoreType {
    // Only store all accepted token indices ... useful when |accepted| < |rejected|.
    kAccepted = 0,
    // Only store all rejected token indices ... useful when |accepted| > |rejected|.
    kRejected = 1,
    // Store all accepted token indices in a bitset, when both are large.
    kAcceptedBitset = 2
  };
  StoreType store_type;
  static constexpr int USE_BITSET_THRESHOLD = 1000;
  std::vector<int32_t> accepted_indices;
  std::vector<int32_t> rejected_indices;
  DynamicBitset accepted_bitset;
  std::vector<int32_t> uncertain_indices;
};
```
The header comment (cpp/compiled_grammar_impl.h:28-39) states it directly: accepted and rejected tokens are decided by this state alone, while uncertain tokens “need the state of the parent ParserStates to determine if acceptable.” Accepted+rejected are your context-independent tokens; uncertain are context-dependent. StoreType is the “adaptive” part: store whichever side is smaller, and fall back to a bitset when both sides are large (USE_BITSET_THRESHOLD = 1000).
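The store-type decision reduces to a small function. This Python sketch assumes the rule stated in the struct comments (keep the smaller side, switch to a bitset when both are large); the exact comparison xgrammar uses may differ, and choose_store, StoredMask, and accepted_set are hypothetical names. It also shows how each layout is expanded back into the set of context-independent accepted indices.

```python
from dataclasses import dataclass, field
from enum import Enum

USE_BITSET_THRESHOLD = 1000  # value taken from compiled_grammar_impl.h


class StoreType(Enum):
    ACCEPTED = 0
    REJECTED = 1
    ACCEPTED_BITSET = 2


@dataclass
class StoredMask:
    store_type: StoreType
    indices: list = field(default_factory=list)    # accepted OR rejected indices
    bitset: int = 0                                # Python int used as a bitset
    uncertain: list = field(default_factory=list)


def choose_store(accepted, rejected, uncertain):
    """Pick the compact layout for one automaton state's three-way split."""
    if min(len(accepted), len(rejected)) >= USE_BITSET_THRESHOLD:
        bits = 0
        for i in accepted:
            bits |= 1 << i
        return StoredMask(StoreType.ACCEPTED_BITSET, bitset=bits, uncertain=list(uncertain))
    if len(accepted) <= len(rejected):
        return StoredMask(StoreType.ACCEPTED, indices=sorted(accepted), uncertain=list(uncertain))
    return StoredMask(StoreType.REJECTED, indices=sorted(rejected), uncertain=list(uncertain))


def accepted_set(mask, vocab_size):
    """Accepted indices, excluding uncertain ones (the three classes partition the vocab)."""
    if mask.store_type is StoreType.ACCEPTED:
        return set(mask.indices)
    if mask.store_type is StoreType.ACCEPTED_BITSET:
        return {i for i in range(vocab_size) if (mask.bitset >> i) & 1}
    return set(range(vocab_size)) - set(mask.indices) - set(mask.uncertain)
```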
The compiled grammar is literally a hash map from automaton state to this struct (cpp/compiled_grammar_impl.h:119):
```cpp
std::unordered_map<ParserState, AdaptiveTokenMask, StateHashForCache> adaptive_token_mask_cache;
```
and the cache key deliberately ignores the input-position field, which is what makes the precompute per-automaton-state rather than per-parse-configuration:
```cpp
// cpp/earley_parser.h:154
class StateHashForCache {
 public:
  size_t operator()(const ParserState& state) const {
    return HashCombine(state.rule_id, state.sequence_id, state.element_id, state.sub_element_id);
  }
};
```
The classification itself happens in GetAdaptiveTokenMask (cpp/grammar_compiler.cc:780) → GetTokenMaskWithFirstCharacterCheck. For each token (in sorted-vocab order), feed its bytes to a parser seeded with only this rule’s state and no parent context:
```cpp
// cpp/grammar_compiler.cc:637 (inside the per-token loop)
bool can_reach_end = tmp_can_reach_end_prefix_or_stack_.back();
if (accepted) {
  tmp_accepted_indices_.push_back(i);
} else if (can_reach_end && prev_matched_size > 0) {
  auto [lookahead_accepted, lookahead_completed] =
      IsTokenPassLookaheadAssertion(token, tmp_can_reach_end_stack_);
  if ((!is_root_rule) && lookahead_accepted) {
    ...
    tmp_uncertain_indices_.push_back(i);
    ...
  }
} else {
  tmp_rejected_indices_.push_back(i);
}
```
Read that middle branch carefully — it is your “token straddles a grammar boundary” case. The token’s bytes were not fully consumed by this rule, but at some byte the rule could have completed (can_reach_end): whether the remaining bytes are legal depends on whatever rule the parent stack continues with. At compile time that’s unknowable → uncertain. Two refinements cut the uncertain set: for the root rule there is no parent, so nothing is uncertain (is_root_rule), and XGrammar-2’s lookahead assertions (the rule’s right context, when statically known) let some straddling tokens be resolved to accepted/rejected at compile time after all (IsTokenPassLookaheadAssertion, cpp/grammar_compiler.cc:312).
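The three-way split can be reproduced on a toy rule. This Python sketch is a simplification, not GetTokenMaskWithFirstCharacterCheck: a rule is a deterministic character FSM (a hypothetical (transitions, finals) encoding), there are no lookahead assertions and no prefix sharing, and a token counts as straddling only if the rule could have completed after at least one of its characters. That last condition is an interpretation of the prev_matched_size > 0 test: if the rule can complete before consuming anything, the parent's continuation state is already live in the Earley set and covers such tokens with its own mask.

```python
def classify(rule, start, vocab, is_root_rule=False):
    """Split vocab indices into (accepted, rejected, uncertain) for one rule-FSM state.

    rule = (transitions, finals): transitions maps state -> {char: next_state}.
    """
    transitions, finals = rule
    accepted, rejected, uncertain = [], [], []
    for idx, token in enumerate(vocab):
        state, could_end, consumed_all = start, False, True
        for ch in token:
            state = transitions.get(state, {}).get(ch)
            if state is None:
                consumed_all = False
                break
            # The rule could stop here, after at least one character of this token.
            could_end = could_end or state in finals
        if consumed_all:
            accepted.append(idx)    # decidable from this state alone
        elif could_end and not is_root_rule:
            uncertain.append(idx)   # leftover chars depend on what the parent does next
        else:
            rejected.append(idx)    # fails regardless of context, or there is no parent
    return accepted, rejected, uncertain


DIGITS = {d: 1 for d in "0123456789"}
INTEGER_RULE = ({0: DIGITS, 1: DIGITS}, {1})  # hypothetical rule: [0-9]+
```

With vocab ["1", "42", "7,", "a", ",", "3}"], tokens "7," and "3}" come out uncertain for a nested integer rule (the comma or brace belongs to whatever follows), but rejected when the integer is the root; a statically known right context is exactly what lets the lookahead assertions settle such tokens at compile time.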
4.3 GrammarMatcher: runtime state and the two core calls
The runtime automaton state is a set of ParserStates (cpp/earley_parser.h:41 — rule_id, sequence_id, element_id within the rule FSM, rule_start_pos back-pointer, sub_element_id for mid-UTF-8 positions). The parser keeps full history (scanable_state_history_, a Compact2DArray) so any number of tokens can be rolled back — that’s what makes jump-forward and speculative decoding cheap. Advancing by one byte is classic Earley scan/predict/complete:
```cpp
// cpp/earley_parser.cc:257
bool EarleyParser::Advance(const uint8_t ch, bool debug_print) {
  ...
  const auto& latest_states = scanable_state_history_[scanable_state_history_.size() - 1];
  for (const auto& state : latest_states) Scan(state, ch);
  if (tmp_process_state_queue_.empty() && tmp_states_to_be_added_.empty()) return false;
  while (!tmp_process_state_queue_.empty()) {
    const auto state = std::move(tmp_process_state_queue_.front());
    tmp_process_state_queue_.pop();
    auto [scanable, completable] = Predict(state, debug_print);
    if (completable) Complete(state, debug_print);
    if (scanable) tmp_states_to_be_added_.push_back(state);
  }
  is_completed_.push_back(tmp_accept_stop_token_);
  scanable_state_history_.PushBack(tmp_states_to_be_added_);
  return true;
}
```
Accepting a sampled token (GrammarMatcher::Impl::AcceptToken, cpp/grammar_matcher.cc:610) tries two paths and merges: an atomic path over kToken/kExcludeToken edges (the token matched as a unit), and the byte path (the token’s decoded bytes fed through Advance one at a time, cpp/grammar_matcher.cc:668-676). Stop tokens are only accepted when IsCompleted() (:634-641); rejection pops any partially-advanced bytes so the state is untouched.
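Scan/predict/complete with a full state history fits in a small Python class. This is the textbook character-level Earley recognizer over plain CFG productions, not xgrammar's variant over per-rule FSMs (no sub_element_id, no token edges, no cached masks), and the class name and grammar encoding are hypothetical. It keeps one frozen state set per consumed character, so rolling back n characters is just dropping n sets, and accept_string drops whatever it pushed when a character fails, mirroring how a rejected AcceptToken leaves the state untouched.

```python
class EarleyRecognizer:
    """Textbook Earley recognizer with a per-character state-set history.

    grammar: dict nonterminal -> list of productions (tuples of symbols). A symbol that
    is a key of `grammar` is a nonterminal; any other symbol is a literal character.
    Assumes no empty productions and no single-character nonterminal names.
    Items are (lhs, production index, dot position, origin set index).
    """

    def __init__(self, grammar, root="root"):
        self.grammar, self.root = grammar, root
        self.history = []  # history[k] = state set after k characters
        start = [(root, i, 0, 0) for i in range(len(grammar[root]))]
        self.history.append(self._closure(start, 0))

    def _next_symbol(self, item):
        lhs, prod, dot, _origin = item
        rhs = self.grammar[lhs][prod]
        return rhs[dot] if dot < len(rhs) else None

    def _closure(self, items, pos):
        """Predict and complete until set `pos` stops growing."""
        state_set, queue = set(items), list(items)
        while queue:
            item = queue.pop()
            lhs, _prod, _dot, origin = item
            sym = self._next_symbol(item)
            if sym is None:  # complete: advance items in the origin set waiting on `lhs`
                new = [(l, p, d + 1, o) for (l, p, d, o) in self.history[origin]
                       if self._next_symbol((l, p, d, o)) == lhs]
            elif sym in self.grammar:  # predict: start every production of `sym` here
                new = [(sym, i, 0, pos) for i in range(len(self.grammar[sym]))]
            else:  # literal: waits for a scan
                new = []
            for n in new:
                if n not in state_set:
                    state_set.add(n)
                    queue.append(n)
        return frozenset(state_set)

    def advance(self, ch):
        """Scan one character. On failure return False and leave the history untouched."""
        scanned = [(l, p, d + 1, o) for (l, p, d, o) in self.history[-1]
                   if self._next_symbol((l, p, d, o)) == ch]
        if not scanned:
            return False
        self.history.append(self._closure(scanned, len(self.history)))
        return True

    def rollback(self, n):
        """Drop the last n state sets (never the initial one)."""
        assert 0 <= n < len(self.history)
        del self.history[len(self.history) - n:]

    def is_completed(self):
        return any(l == self.root and d == len(self.grammar[l][p]) and o == 0
                   for (l, p, d, o) in self.history[-1])

    def accept_string(self, s):
        """Feed characters one at a time; on failure pop everything this call pushed."""
        for k, ch in enumerate(s):
            if not self.advance(ch):
                self.rollback(k)
                return False
        return True
```

For example, with root ::= "[" items "]" | "[" "]", items ::= digit | digit "," items, and digit ::= "0" | "1", accepting "[0,1" leaves the recognizer incomplete, a later ",]" is rejected without changing the history, and "]" completes it.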
Filling the mask (FillNextTokenBitmask, cpp/grammar_matcher.cc:809) is where compile-time and runtime meet. For every live Earley state, look up its precomputed AdaptiveTokenMask; OR together the accepted sets, AND together the rejected sets (nondeterminism = union of possibilities):
```cpp
// cpp/grammar_matcher.cc:842
for (const auto& state : latest_states) {
  auto adaptive_token_mask_it = adaptive_token_mask_cache.find(state);
  XGRAMMAR_CHECK(adaptive_token_mask_it != adaptive_token_mask_cache.end()) << state;
  const auto& adaptive_token_mask = adaptive_token_mask_it->second;
  ...
  if (adaptive_token_mask.store_type == StoreType::kAcceptedBitset) {
    tmp_accepted_bitset_ |= adaptive_token_mask.accepted_bitset;
  } else if (adaptive_token_mask.store_type == StoreType::kAccepted) {
    for (auto idx : adaptive_token_mask.accepted_indices)
      tmp_accepted_bitset_.Set(sorted_decoded_vocab[idx].first, true);
  }
}
```
Then the uncertain tokens are resolved on the fly by actually stepping the live parser through their bytes (cpp/grammar_matcher.cc:880-940) — this is the only per-step work proportional to grammar context, and it’s bounded by the uncertain count, not the vocab size. Finally SetTokenBitmask (cpp/grammar_matcher.cc:1035) writes the bitset, turning stop tokens on only if IsCompleted() — that’s how “end of structure → only EOS legal” falls out. The function returns need_apply=false when the mask is all-ones, letting engines skip the GPU apply entirely (python/xgrammar/matcher.py:334).
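The merge can be sketched with Python sets standing in for bitsets. Assumptions: each live state is a (rule_id, element_id, rule_start_pos) tuple, the precomputed table is keyed by (rule_id, element_id) only (mimicking how StateHashForCache leaves out the input position), each entry is ("accepted" | "rejected", indices, uncertain), and resolve_uncertain is a hypothetical callback standing in for stepping the live parser. A token is allowed if any live state allows it, which is why accepted sets are unioned and rejected sets intersected.

```python
def fill_next_token_mask(live_states, mask_table, resolve_uncertain,
                         vocab_size, stop_token_ids, is_completed):
    """Merge precomputed per-state masks into one allowed-token set.

    live_states: iterable of (rule_id, element_id, rule_start_pos) tuples.
    mask_table: (rule_id, element_id) -> (store, indices, uncertain), with store
        "accepted" or "rejected"; rule_start_pos is deliberately not part of the key.
    resolve_uncertain: token_id -> bool, a stand-in for running the live parser.
    Returns (allowed_token_ids, need_apply).
    """
    accepted = set()  # union over accepted-store states
    rejected = None   # intersection over rejected-store states
    for rule_id, element_id, _rule_start_pos in live_states:
        store, indices, uncertain = mask_table[(rule_id, element_id)]
        runtime_ok = {t for t in uncertain if resolve_uncertain(t)}
        if store == "accepted":
            accepted |= set(indices) | runtime_ok
        else:
            state_rejected = set(indices) | (set(uncertain) - runtime_ok)
            rejected = state_rejected if rejected is None else rejected & state_rejected
    allowed = set(accepted)
    if rejected is not None:
        allowed |= set(range(vocab_size)) - rejected
    allowed -= set(stop_token_ids)  # stop tokens are legal only when the structure is complete
    if is_completed:
        allowed |= set(stop_token_ids)
    return allowed, len(allowed) < vocab_size  # need_apply is False for an all-ones mask
```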
4.4 Tokens vs grammar alignment: bytes, sorted vocab, trie ranges
The grammar operates on bytes, and the vocab is normalized to bytes up front. TokenizerInfo decodes each tokenizer entry through ByteFallbackDecoder (<0x1B> → raw byte 0x1B) or ByteLevelDecoder (inverse of GPT-2 bytes-to-unicode) at cpp/tokenizer_info.cc:47 and :83, so a token is just a byte string and partial UTF-8 inside a token is handled by sub_element_id/partial_codepoint in ParserState.
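Both decoders are small. The byte-level table below is GPT-2's published bytes_to_unicode mapping, inverted; decode_byte_level and decode_byte_fallback are hypothetical Python stand-ins for ByteLevelDecoder and ByteFallbackDecoder, and mapping SentencePiece's "▁" to a space in the fallback path is an assumption. The last test shows why sub_element_id exists: a byte-level token can decode to a lone UTF-8 lead byte.

```python
def bytes_to_unicode():
    """GPT-2's byte -> printable-unicode table."""
    bs = (list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1))
          + list(range(ord("®"), ord("ÿ") + 1)))
    cs = bs[:]
    n = 0
    for b in range(256):
        if b not in bs:  # unprintable bytes are shifted to code points >= 256
            bs.append(b)
            cs.append(256 + n)
            n += 1
    return dict(zip(bs, map(chr, cs)))


UNICODE_TO_BYTE = {c: b for b, c in bytes_to_unicode().items()}


def decode_byte_level(token: str) -> bytes:
    """Undo GPT-2's byte -> unicode mapping, one character per byte."""
    return bytes(UNICODE_TO_BYTE[c] for c in token)


def decode_byte_fallback(token: str) -> bytes:
    """'<0xNN>' is a raw byte; anything else is UTF-8 text with '▁' read as a space (assumed)."""
    if len(token) == 6 and token.startswith("<0x") and token.endswith(">"):
        return bytes([int(token[3:5], 16)])
    return token.replace("\u2581", " ").encode("utf-8")
```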
Matching 100k token strings per state is made tractable by sorting the decoded vocab lexicographically and precomputing, for each token, the half-open range of tokens that have it as a prefix — a flattened prefix trie:
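```cpp
// cpp/tokenizer_info.cc:306 (excerpt; prefix test and end-of-vocab drain written out explicitly)
// The value means: the subtree is [i, trie_subtree_nodes_range[i]).
trie_subtree_nodes_range_.resize(sorted_decoded_vocab_.size(), 0);
std::stack<std::pair<std::string, int32_t>> prefix_stack;
for (size_t i = 0; i < sorted_decoded_vocab_.size(); ++i) {
  const auto& token = sorted_decoded_vocab_[i].second;
  // Pop every stacked token that is not a prefix of this one: its subtree ends at i.
  while ((!prefix_stack.empty()) && (token.rfind(prefix_stack.top().first, 0) != 0)) {
    auto top_pair = prefix_stack.top();
    trie_subtree_nodes_range_[top_pair.second] = i;
    prefix_stack.pop();
  }
  prefix_stack.push({token, i});
}
// Tokens still on the stack have subtrees that run to the end of the vocab.
while (!prefix_stack.empty()) {
  trie_subtree_nodes_range_[prefix_stack.top().second] = sorted_decoded_vocab_.size();
  prefix_stack.pop();
}
```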
Both the compile-time classifier and the runtime uncertain-resolver exploit this: consecutive sorted tokens share prefixes, so the parser only advances/pops the differing suffix (the std::mismatch LCP logic at cpp/grammar_compiler.cc:596 and cpp/grammar_matcher.cc:901), and when a prefix is rejected the entire subtree range is skipped (subtree_range[cur_token_idx], cpp/grammar_matcher.cc:888). Note the masks store sorted-vocab indices, translated to token IDs only when writing the final bitmask.
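The flattened trie and the subtree skip fit in a few lines of Python. This sketch uses plain strings instead of byte strings and a hypothetical is_viable predicate in place of the parser. Skipping a rejected token's whole range is only valid when rejection is prefix-closed (if a token fails, every token that extends it fails too), which holds when the predicate is "can the parser consume this string".

```python
def subtree_ranges(sorted_vocab):
    """ranges[i] = end of the half-open block [i, ranges[i]) of tokens starting with sorted_vocab[i]."""
    ranges, stack = [0] * len(sorted_vocab), []
    for i, tok in enumerate(sorted_vocab):
        # In sorted order, once a token stops extending the stack top, none later will.
        while stack and not tok.startswith(sorted_vocab[stack[-1]]):
            ranges[stack.pop()] = i
        stack.append(i)
    while stack:
        ranges[stack.pop()] = len(sorted_vocab)
    return ranges


def scan_with_skips(sorted_vocab, ranges, is_viable):
    """Collect viable tokens; when a token fails, jump past every token it prefixes."""
    allowed, checks, i = [], 0, 0
    while i < len(sorted_vocab):
        checks += 1
        if is_viable(sorted_vocab[i]):
            allowed.append(sorted_vocab[i])
            i += 1
        else:
            i = ranges[i]
    return allowed, checks
```

On the sorted vocab ["a", "ab", "abc", "b", "ba", "c", "ca"] with a predicate that only admits the letters a and b, the scan returns ["a", "ab", "b", "ba"] after 6 checks: rejecting "c" skips "ca" without looking at it.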
4.5 Mask application on GPU
The bitmask layout is ceil(vocab_size/32) int32 words per row (GetBitmaskSize → DynamicBitset::GetBufferSize, cpp/grammar_matcher.cc:36; shape helper at python/xgrammar/matcher.py:19). Bit=1 means allowed; bit=0 means set the logit to −inf. For a 128k vocab that’s 16 KB per request — small enough that engines copy it CPU→GPU every step. The default GPU path is Triton:
```python
# python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py:64
# (row_id, block_offset, BLOCK_SIZE and the strides are set up earlier in the kernel)
batch_id = row_id if indices_ptr is None else tl.load(indices_ptr + row_id)
offsets = block_offset + tl.arange(0, BLOCK_SIZE)
bitmask_offsets = block_offset // 32 + tl.arange(0, BLOCK_SIZE // 32)
vocab_mask = offsets < vocab_size
packed_bitmask_mask = bitmask_offsets < bitmask_strides
packed_bitmask = tl.load(
    bitmask_ptr + batch_id * bitmask_strides + bitmask_offsets, packed_bitmask_mask
)
bitmask = ((packed_bitmask[:, None] >> (tl.arange(0, 32)[None, :])) & 1) == 0
bitmask = bitmask.reshape(BLOCK_SIZE)
tl.store(logits_ptr + batch_id * logits_strides + offsets, -float("inf"), vocab_mask & bitmask)
```
The optional indices argument lets a mixed batch apply masks only to the rows of structured requests. The CUDA twin lives in python/xgrammar/kernels/apply_token_bitmask_inplace_cuda.cu (BITS_PER_BLOCK = 32 at :34, vectorized LogitsBitmaskKernel at :64, JIT-compiled through torch.utils.cpp_extension). Backend dispatch (cpu / triton / cuda / torch_compile / torch_native / mlx) is in python/xgrammar/matcher.py:58-188. There is also a pure-C++ CPU apply (ApplyTokenBitmaskInplaceCPU, include/xgrammar/matcher.h:32).
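The layout is easy to check with a pure-Python sketch of packing and applying the mask. It follows the convention described in this section (bit i % 32 of word i // 32, 1 = allowed, least significant bit first, as in the Triton kernel's shift) and mimics the optional indices argument. The function names are hypothetical, Python ints stand in for the int32 tensor, and a real engine does the apply on the GPU.

```python
import math


def bitmask_size(vocab_size: int) -> int:
    """Number of int32 words per row: ceil(vocab_size / 32)."""
    return (vocab_size + 31) // 32


def pack_allowed(allowed_ids, vocab_size):
    """Bit (i % 32) of word (i // 32) is 1 iff token i is allowed; words are signed int32."""
    words = [0] * bitmask_size(vocab_size)
    for i in allowed_ids:
        words[i >> 5] |= 1 << (i & 31)
    return [w - (1 << 32) if w >= (1 << 31) else w for w in words]


def apply_bitmask_inplace(logits_rows, bitmask_rows, indices=None):
    """Set disallowed logits to -inf; with `indices`, only those batch rows are masked."""
    rows = range(len(logits_rows)) if indices is None else indices
    for r in rows:
        logits, words = logits_rows[r], bitmask_rows[r]
        for i in range(len(logits)):
            if ((words[i >> 5] >> (i & 31)) & 1) == 0:
                logits[i] = -math.inf
```

For a 131,072-token vocab, bitmask_size gives 4096 words, i.e. the 16 KiB per request quoted above.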
4.6 Jump-forward / forced tokens
When the grammar forces output (e.g. {"name": " after {), the engine shouldn’t pay a forward pass for every forced token. FindJumpForwardString walks the live states’ FSM edges and extends the result while every state agrees on exactly one possible next byte:
```cpp
// cpp/grammar_matcher.cc:984 (inside the while loop)
// 1. Find the single next char that every live state agrees on.
int next_char = -1;
for (const auto& state : states) {
  const auto& fsm = grammar_->per_rule_fsms[state.rule_id].value();
  const auto& current_edges = fsm.GetFsm().GetFsm().GetEdges(state.element_id);
  for (const auto& edge : current_edges) {
    if (!edge.IsCharRange()) continue;
    if (edge.min != edge.max) { can_find_next_char = false; break; }
    if (next_char == -1) next_char = edge.min;
    else if (next_char != edge.min) { can_find_next_char = false; break; }
  }
}
// 2. If found, accept the char and iterate; finally PopLastStates(num_accepted_chars)
```
It advances the parser one byte at a time as the string grows, then rolls all of those bytes back (PopLastStates), so the matcher state is unchanged. The engine retokenizes the returned string and force-appends it; that retokenization can shift token boundaries, which is why Rollback (cpp/grammar_matcher.cc:1023) exists. SGLang uses this (try_jump_forward); vLLM v1 currently only references it in a comment (backend_xgrammar.py:137). The related single-token shortcut _IsSingleTokenBitmask (include/xgrammar/matcher.h:30) lets an engine skip sampling when the mask allows exactly one token.
4.7 Speculative decoding hook
TraverseDraftTree (include/xgrammar/matcher.h:130, impl cpp/grammar_matcher.cc:55-110) DFS-walks a draft token tree, calling AcceptToken/FillNextTokenBitmask/Rollback at each node to produce one mask per draft position — the grammar-side counterpart of tree-based spec decode. Unlimited rollback is what makes this and jump-forward safe (python/xgrammar/matcher.py:230: “the new Earley parser … allows unlimited rollback”).
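The DFS pattern is independent of the parser, so it can be shown with a toy matcher. ToyMatcher is a hypothetical stand-in for GrammarMatcher whose legal outputs are a fixed set of strings, and traverse_draft_tree is modeled on the description above, not on the real function's signature or output layout. Each child is accepted, explored, and rolled back by one token, so the matcher ends exactly where it started; a draft that leaves the grammar simply gets no masks for its subtree.

```python
class ToyMatcher:
    """Hypothetical stand-in for GrammarMatcher: legal outputs are a fixed set of strings."""

    def __init__(self, legal, vocab):
        self.legal, self.vocab, self.history = legal, vocab, []

    def _text(self, extra=""):
        return "".join(self.vocab[t] for t in self.history) + extra

    def accept_token(self, tid):
        if any(s.startswith(self._text(self.vocab[tid])) for s in self.legal):
            self.history.append(tid)
            return True
        return False

    def fill_next_token_bitmask(self):
        return {t for t, piece in enumerate(self.vocab)
                if any(s.startswith(self._text(piece)) for s in self.legal)}

    def rollback(self, n):
        del self.history[len(self.history) - n:]


def traverse_draft_tree(matcher, draft_tokens, children, node=0, masks=None):
    """DFS over a draft tree; node 0 is the current position and carries no token.

    Returns {node: allowed-token set} for every node reachable through legal drafts.
    """
    masks = {} if masks is None else masks
    masks[node] = matcher.fill_next_token_bitmask()
    for child in children.get(node, []):
        if matcher.accept_token(draft_tokens[child]):
            traverse_draft_tree(matcher, draft_tokens, children, child, masks)
            matcher.rollback(1)  # undo this child before trying its sibling
        # else: the draft left the grammar here; its subtree gets no masks
    return masks
```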
4.8 Engine integration surface
The contract an engine implements per decode step (see docs/tutorials/engine_integration.md):
```python
# per model:
#   tokenizer_info = xgr.TokenizerInfo.from_huggingface(tok, vocab_size=config.vocab_size)
#   compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)  # holds the LRU cache
# once, for the batch:
#   token_bitmask = xgr.allocate_token_bitmask(batch_size, tokenizer_info.vocab_size)
# per request:
#   matcher = xgr.GrammarMatcher(compiler.compile_json_schema(schema))
# per step, for the request in row i:
matcher.fill_next_token_bitmask(token_bitmask, i)                          # CPU, one row per request
xgr.apply_token_bitmask_inplace(logits, token_bitmask.to(logits.device))  # GPU
# after sampling:
assert matcher.accept_token(next_token_id)
```
BatchGrammarMatcher (cpp/grammar_matcher.cc:1117) parallelizes mask fills across requests with an internal thread pool — fill time is per-request CPU work on the critical path between forward pass and sampling, so engines overlap or thread it.
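The fan-out itself is simple to sketch with concurrent.futures. Assumptions: each matcher exposes a fill_next_token_bitmask() that returns its row (the real xgrammar call writes into a preallocated tensor instead), parallel_threshold is a hypothetical tuning knob, ConstMatcher is a hypothetical stand-in, and threads only pay off if the per-row fill releases the GIL. The point is the shape: rows are independent per request, so small batches run serially and large ones fan out.

```python
from concurrent.futures import ThreadPoolExecutor


def fill_batch(matchers, pool, parallel_threshold=128):
    """Return one mask row per request, using `pool` only for large batches."""
    rows = [None] * len(matchers)

    def fill(i):
        # Each task touches only its own matcher and its own slot: no locking needed.
        rows[i] = matchers[i].fill_next_token_bitmask()

    if len(matchers) <= parallel_threshold:
        for i in range(len(matchers)):
            fill(i)
    else:
        list(pool.map(fill, range(len(matchers))))  # list() surfaces worker exceptions
    return rows


class ConstMatcher:
    """Hypothetical stand-in whose 'mask' is just the set {i}."""

    def __init__(self, i):
        self.i = i

    def fill_next_token_bitmask(self):
        return {self.i}
```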
In your local checkouts:
- vLLM (materials/vllm/vllm): backend at vllm/v1/structured_output/backend_xgrammar.py (fill_bitmask → matcher.fill_next_token_bitmask at :191, allocate_token_bitmask at :124). Orchestration lives in vllm/v1/structured_output/__init__.py: a StructuredOutputManager that runs fills on a ThreadPoolExecutor once the structured batch exceeds 128 requests (:62-64). The bitmask travels scheduler→worker, and the GPU apply is vLLM’s own Triton kernel, _apply_grammar_bitmask_kernel in vllm/v1/worker/gpu/structured_outputs.py:86, with the same bit layout reimplemented engine-side.
- SGLang (materials/sglang/sglang): backend at python/sglang/srt/constrained/xgrammar_backend.py: accept_token (:79), fill_vocab_mask (:105), apply_vocab_mask dispatching to xgrammar’s kernels (:112), and jump-forward wired up via try_jump_forward (:144). The backend is selected among the alternatives in base_grammar_backend.py / grammar_manager.py in the same directory.
5. Suggested reading path
1. docs/tutorials/engine_integration.md: the whole loop in 60 lines of Python; confirms your model end-to-end.
2. include/xgrammar/matcher.h + include/xgrammar/compiler.h: the core of the public API surface, ~370 lines.
3. cpp/compiled_grammar_impl.h: AdaptiveTokenMask and the state→mask map; the central data structure.
4. cpp/earley_parser.h, then cpp/earley_parser.cc:257 (Advance): the runtime automaton; skim Scan/Predict/Complete.
5. cpp/grammar_matcher.cc: AcceptToken (:610), FillNextTokenBitmask (:809), FindJumpForwardString (:964), SetTokenBitmask (:1035).
6. cpp/grammar_compiler.cc: MultiThreadCompileGrammar (:1048), GetAdaptiveTokenMask (:780), the classification loop (:594-665), and the two cache layers (:1394-1400).
7. cpp/tokenizer_info.cc:266-323: vocab decode, sort, trie ranges; plus the byte decoders at :27-100.
8. python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py: 130 lines; then compare vLLM’s copy in vllm/v1/worker/gpu/structured_outputs.py.
Optional ninth: cpp/json_schema_converter.cc (only if you want schema→EBNF details) and tests/python/test_grammar_matcher_basic.py for executable documentation.
6. Connections to your other study repos
- vllm: consumer. Integration tier: vllm/v1/structured_output/ (manager + xgrammar/guidance/outlines backends behind backend_types.py), GPU apply in vllm/v1/worker/gpu/structured_outputs.py. Note the division of labor: xgrammar fills masks in the scheduler process and the worker applies them, so the bitmask is part of scheduler→worker IPC.
- sglang: consumer. python/sglang/srt/constrained/ mirrors the same backend abstraction and additionally exercises jump-forward. SGLang historically drove the jump-forward idea (its Outlines backend has outlines_jump_forward.py).
- nano-vllm: no structured output at all. Adding it is a clean exercise: (1) hold a GrammarMatcher per sequence in nanovllm/engine/scheduler.py’s sequence state, (2) in nanovllm/engine/model_runner.py, fill a bitmask for grammar-bearing sequences and apply_token_bitmask_inplace on the logits before nanovllm/layers/sampler.py runs, (3) call accept_token on the sampled IDs, and reset/free the matcher on finish. That is roughly 100 lines; the subtleties are batching (indices for mixed batches) and not blocking the step on compile (vLLM compiles asynchronously and treats the grammar as “not ready” until done).
- flashinfer: sibling, not a dependency. Same architectural slot as xgrammar’s kernels: small device ops bracketing sampling. In a vLLM step the order is xgrammar bitmask apply → flashinfer/torch sampling kernels. There is no xgrammar code in the flashinfer repo.
- dynamo: mostly unrelated (routing/disaggregation tier). Its only touchpoint is passing guided_decoding_backend: xgrammar|llguidance through to TensorRT-LLM (components/src/dynamo/trtllm/backend_args.py:231). One traffic-relevant nuance: structured requests change per-step cost (CPU mask fill on the critical path, compile spikes on first sight of a schema), which a router could in principle account for; none of these do today.
- llm-d, gateway-api-inference-extension: honestly unrelated. Structured output is opaque request payload by the time it reaches the gateway/EPP tier; nothing to find there.
7. Hands-on on your machine (RTX 5080, 16 GB, Windows)
pip install xgrammar installs prebuilt wheels, Windows included. Caveat: the Triton dependency is Linux-x86_64-only (pyproject.toml:22), and the cuda backend JIT-compiles via nvcc+MSVC, so on native Windows the GPU mask apply falls back to torch_compile. That doesn’t matter much: everything interesting here is CPU-only. The GPU never sees the grammar, only a packed bitmask of about vocab_size/8 bytes per request (16 KiB for a 128k vocab). (For the Triton path, use WSL2.)
Experiment 1: watch the mask evolve (no GPU)
```python
import time

import xgrammar as xgr
from transformers import AutoTokenizer
from xgrammar.testing import _get_masked_tokens_from_bitmask

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
info = xgr.TokenizerInfo.from_huggingface(tok)
compiler = xgr.GrammarCompiler(info, max_threads=8)
schema = '{"type":"object","properties":{"name":{"type":"string"},"port":{"type":"integer"}},"required":["name","port"]}'

t0 = time.perf_counter()
cg = compiler.compile_json_schema(schema)
print(f"compile {(time.perf_counter() - t0) * 1e3:.1f} ms, mask cache {cg.memory_size_bytes / 2**20:.2f} MiB")

matcher = xgr.GrammarMatcher(cg)
bitmask = xgr.allocate_token_bitmask(1, info.vocab_size)
for tid in tok.encode('{"name": "edge-proxy", "port": 8443}', add_special_tokens=False):
    t0 = time.perf_counter()
    matcher.fill_next_token_bitmask(bitmask)  # row 0 of the bitmask
    us = (time.perf_counter() - t0) * 1e6
    allowed = info.vocab_size - len(_get_masked_tokens_from_bitmask(bitmask, info.vocab_size))
    print(f"fill {us:7.1f} us  allowed {allowed:6d}  jump={matcher.find_jump_forward_string()!r}  next={tok.decode([tid])!r}")
    assert matcher.accept_token(tid)
```
Call compiler.compile_json_schema(schema) a second time in the same process and it returns almost immediately: the grammar-level LRU cache lives inside the GrammarCompiler, so a fresh interpreter always starts cold. The jump= column shows forced text ({"name": " etc.) shrinking to '' whenever the model has a real choice. Add debug_print=True to fill_next_token_bitmask/accept_token, or call matcher._debug_print_internal_state(), to see the live Earley states.
Experiment 2: mask-fill latency on a gnarly schema
Fill cost scales with live states × uncertain tokens, so stress it: deep nesting, anyOf fans, and strings constrained by pattern or format (email/uuid-style patterns produce character-class-heavy automata where few tokens are decidable per state). Compare against the trivial schema above; also try compiler.compile_builtin_json_grammar() (free-form JSON, highly recursive) and a pathological EBNF of your own via compiler.compile_grammar(ebnf). Time three things separately:
- the first compile (cold);
- fill_next_token_bitmask per step along a long valid output (use matcher.accept_string(...) to fast-forward);
- BatchGrammarMatcher.batch_fill_next_token_bitmask with 64 forked matchers (matcher.fork()), to see the threaded path engines actually use.

The repo’s own harness for exactly this comparison is examples/benchmark/bench_grammar_compile_mask_gen.py (xgrammar vs outlines vs llguidance), with examples/benchmark/bench_apply_token_bitmask_inplace.py for the kernel side.
Experiment 3 (GPU, optional): end-to-end with a real model
The HF integration in python/xgrammar/contrib/hf.py provides a LogitsProcessor; with a 0.5–3B model on the 5080 you can run docs/tutorials/workflow_of_xgrammar.md’s generation loop verbatim and check whether the per-step overhead is dominated by the CPU fill rather than the mask apply. On native Windows pass backend="torch_compile" to apply_token_bitmask_inplace (or run under WSL2 for Triton).