FlashInfer
A study guide to the kernel library that sits at the very bottom of your inference stack.
Repo: materials/flashinfer/flashinfer (github.com/flashinfer-ai/flashinfer). All paths below
are relative to the repo root; line numbers match this checkout.
1. What it is
FlashInfer is the GPU kernel library and kernel generator that inference engines call for the
operations that dominate a serving step: attention (prefill, decode, MLA, cascade/shared-prefix,
sparse), sampling (top-k/top-p/min-p on device), plus GEMM/MoE, RoPE, and norms. It is not an
engine — it has no scheduler, no server, no model code. It exposes PyTorch-callable ops that are
JIT-compiled per configuration (dtype, head_dim, feature flags) and cached as .so files.
Who consumes it:
- SGLang: its standard attention backend on NVIDIA GPUs (--attention-backend flashinfer) and its on-GPU sampling kernels.
- vLLM: optional attention backend (VLLM_ATTENTION_BACKEND=FLASHINFER), and the channel through which NVIDIA’s trtllm-gen Blackwell kernels reach vLLM.
- Also TensorRT-LLM, TGI, MLC-LLM, LightLLM (see README.md:231-242).
The project’s own CLAUDE.md is an unusually good architecture document — read it alongside this
guide.
2. Why you care
- It defines the per-step cost model. Every latency assumption upstream — scheduler token budgets, router ITL estimates, autoscaler capacity curves — bottoms out at: decode step time ≈ (weight bytes + batch KV bytes) / HBM bandwidth; prefill time ≈ FLOPs / tensor-core throughput. Those constants are set by the kernels in this repo. When SGLang’s scheduler decides how many tokens fit in a step, it is budgeting against these kernels’ behavior.
- The plan/run split is the reason engines look the way they do. Kernels are “planned” on the CPU (workspace layout, work partitioning) once per batch composition, then “run” per layer with zero host decisions — which is what makes CUDA-graph capture, fixed capture batch sizes, and cheap per-layer dispatch possible. It is a literal control-plane/data-plane separation.
- Cascade attention is the kernel-level payoff of prefix sharing. Prefix caching as you know it deduplicates KV storage. The cascade kernels additionally deduplicate KV bandwidth — the shared prefix is read from HBM once for the whole batch instead of once per request. That’s why prefix-aware routing pays off twice.
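The cost model in the first bullet fits in a few lines of arithmetic. The sketch below is a napkin model, not FlashInfer code: the function names are hypothetical, and the example inputs (8 GB of weights, 1.6 GB of resident KV, 960 GB/s of memory bandwidth) are assumptions chosen only to make the numbers concrete.

```python
def decode_step_ms(weight_bytes, batch_kv_bytes, hbm_bytes_per_s):
    """Bandwidth-bound estimate: every weight and KV byte is read once per step."""
    return (weight_bytes + batch_kv_bytes) / hbm_bytes_per_s * 1e3


def prefill_ms(flops, tensor_core_flops_per_s):
    """Compute-bound estimate: FLOPs divided by sustained tensor-core throughput."""
    return flops / tensor_core_flops_per_s * 1e3


# Assumed inputs, for illustration only.
step = decode_step_ms(weight_bytes=8e9, batch_kv_bytes=1.6e9, hbm_bytes_per_s=960e9)
print(f"decode step ~ {step:.1f} ms")
```

Real kernels land below these ideal rates, so treat the result as a lower bound on latency rather than a prediction.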
3. Architecture map
Four layers, with a hard rule: framework-agnostic kernels never see PyTorch headers.
flashinfer/ Python package: public API, one module per op family
decode.py, prefill.py attention wrappers (the plan/run classes)
cascade.py, sampling.py shared-prefix attention; on-GPU sampling
page.py, mla/, pod.py KV-page append, DeepSeek MLA, fused prefill+decode
jit/ JIT machinery: JitSpec, Jinja codegen, ninja builds
core.py, env.py build/load/cache; ~/.cache/flashinfer layout
attention/modules.py gen_*_module() generators + URI naming
csrc/ TVM-FFI bindings + launchers (PyTorch tensor world)
batch_decode.cu Plan() and Run() entry points for decode
*.jinja type-specialization templates rendered at JIT time
include/flashinfer/ header-only CUDA kernel templates (raw pointers only)
page.cuh paged_kv_t — the page table struct kernels consume
attention/decode.cuh CUDA-core decode kernel (bandwidth-bound GEMV)
attention/prefill.cuh tensor-core prefill kernel (compute-bound GEMM)
attention/scheduler.cuh DecodePlan/PrefillPlan — CPU-side work partitioning
attention/cascade.cuh attention-state merge kernels
sampling.cuh sorting-free top-k/top-p rejection sampling
benchmarks/ per-kernel microbenchmarks (no model weights needed)
flashinfer-cubin/, flashinfer-jit-cache/ optional prebuilt-kernel packages
Bindings use TVM-FFI rather than torch extensions, so the same compiled kernels are usable from
non-PyTorch frameworks. 3rdparty/ vendors CUTLASS and spdlog.
4. Core mechanisms
4.1 The plan/run split — control plane vs data plane
Every batch attention API is a wrapper class with two phases. plan() runs on the CPU once per
batch composition (which requests, which lengths); run() executes per layer — 32-80 times per
step — with no host-side decisions. If you squint, plan is route computation and run is packet
forwarding against an installed FIB.
flashinfer/decode.py:901 (BatchDecodeWithPagedKVCacheWrapper.plan) takes the page-table
metadata as host tensors, because the planner needs to read lengths on the CPU, and calls into
C++ BatchDecodeWithPagedKVCachePlan (csrc/batch_decode.cu:39). The real work is
DecodePlan in include/flashinfer/attention/scheduler.cuh:426:
// include/flashinfer/attention/scheduler.cuh:437
FLASHINFER_CUDA_CALL(work_estimation_func(split_kv, max_grid_size, kv_chunk_size_in_pages,
new_batch_size, gdy, batch_size, indptr_h, num_qo_heads,
page_size, enable_cuda_graph, stream));
size_t padded_batch_size;
plan_info.enable_cuda_graph = enable_cuda_graph;
plan_info.split_kv = split_kv;
padded_batch_size =
(enable_cuda_graph) ? (split_kv ? max_grid_size / gdy : batch_size) : new_batch_size;
plan_info.padded_batch_size = padded_batch_size;
auto [request_indices_vec, kv_tile_indices_vec, o_indptr_vec] =
DecodeSplitKVIndptr(indptr_h, batch_size, kv_chunk_size_in_pages);
What the plan actually decides:
- Load balancing. Decode at small batch sizes can’t fill the GPU (one CTA per request per KV head). The planner binary-searches the smallest KV chunk size whose total chunk count still fits the launch grid budget, so long sequences are split into enough CTAs to occupy every SM (PartitionPagedKVCacheBinarySearchMinNumPagePerBatch, include/flashinfer/attention/scheduler.cuh:74). Classic bin-packing, done ahead of time on the CPU.
- Workspace layout. All scratch (partial outputs for split-KV, index arrays) is carved out of one user-provided 128 MB workspace buffer by an offset allocator, with no per-call cudaMalloc.
- One H2D copy. The whole schedule is staged in pinned host memory and shipped with a single cudaMemcpyAsync (scheduler.cuh:489).
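The load-balancing search can be sketched in Python. This is an illustrative reimplementation, not the C++ routine. It assumes the grid budget is a simple upper bound on the CTA count and that gdy is the number of CTAs launched per chunk along the grid's y dimension. The function and parameter names are hypothetical.

```python
def ceil_div(a, b):
    return -(-a // b)


def min_pages_per_chunk(num_pages, max_grid_size, gdy, min_pages=1):
    """Return (pages_per_chunk, num_chunks).

    Finds the smallest chunk size (in pages) whose chunk count, times gdy CTAs
    per chunk, still fits in max_grid_size. Smaller chunks mean more CTAs.
    """
    low, high = min_pages, max([min_pages, *num_pages])
    while low < high:
        mid = (low + high) // 2
        chunks = sum(ceil_div(n, mid) for n in num_pages)
        if chunks * gdy > max_grid_size:
            low = mid + 1  # too many CTAs for the grid: chunks must grow
        else:
            high = mid     # fits: try smaller chunks for more parallelism
    return low, sum(ceil_div(n, low) for n in num_pages)


# One long request of 100 pages, room for 8 CTAs: split into 8 chunks of <= 13 pages.
print(min_pages_per_chunk([100], max_grid_size=8, gdy=1))
```

The search is valid because the chunk count never increases as the chunk size grows, so "fits in the grid" is a monotone predicate.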
The plan is returned to Python as an opaque vector of int64 offsets (DecodePlanInfo::ToVector,
scheduler.cuh:391). At run time, csrc/batch_decode.cu:175 rehydrates raw pointers from it:
// csrc/batch_decode.cu:175
params.request_indices =
GetPtrFromBaseOffset<IdType>(int_buffer, plan_info.request_indices_offset);
params.kv_tile_indices =
GetPtrFromBaseOffset<IdType>(int_buffer, plan_info.kv_tile_indices_offset);
params.o_indptr = GetPtrFromBaseOffset<IdType>(int_buffer, plan_info.o_indptr_offset);
if (plan_info.split_kv) {
tmp_v = GetPtrFromBaseOffset<DTypeO>(float_buffer, plan_info.v_offset);
tmp_s = GetPtrFromBaseOffset<float>(float_buffer, plan_info.s_offset);
}
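The offsets-then-pointers pattern above can be mimicked on the CPU. The sketch below is a toy under stated assumptions: a bytearray stands in for the workspace buffer, and OffsetAllocator, plan, and view_i32 are hypothetical names, not FlashInfer APIs.

```python
import struct


class OffsetAllocator:
    """Bump allocator: hands out aligned byte offsets into one preallocated buffer."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.offset = 0

    def alloc(self, nbytes, align=16):
        start = (self.offset + align - 1) // align * align
        if start + nbytes > self.capacity:
            raise MemoryError("workspace buffer too small")
        self.offset = start + nbytes
        return start


def plan(workspace, request_indices, kv_tile_indices):
    """'Plan' phase: write two int32 arrays into the buffer, return only offsets."""
    allocator = OffsetAllocator(len(workspace))
    offsets = []
    for arr in (request_indices, kv_tile_indices):
        off = allocator.alloc(4 * len(arr))
        struct.pack_into(f"{len(arr)}i", workspace, off, *arr)
        offsets.append(off)
    return offsets


def view_i32(workspace, offset, count):
    """'Run' phase: rebuild a typed view from base buffer + offset, no allocation."""
    return memoryview(workspace)[offset:offset + 4 * count].cast("i")


workspace = bytearray(1024)
req_off, tile_off = plan(workspace, [0, 0, 1], [0, 1, 0])
print(req_off, tile_off, list(view_i32(workspace, req_off, 3)))
```

Because the plan is just a list of integers, it can cross a language boundary cheaply and stays valid for every layer that reuses the same workspace.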
CUDA-graph tie-in: plan() explicitly “cannot be used in CUDA Graph or torch.compile”
(flashinfer/decode.py:1006), while run() is capturable. Under graphs the grid size must be
fixed, so the planner pads the launch to a fixed padded_batch_size and the kernel early-exits
surplus blocks via block_valid_mask (include/flashinfer/attention/decode.cuh:423). This is
why engines capture graphs at a fixed set of batch sizes and bucket requests into them.
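A minimal sketch of the bucketing idea follows. The capture-size list and both function names are assumptions for illustration. In the real planner the padded size counts work items after any KV split, which this toy ignores.

```python
import bisect

CAPTURE_BATCH_SIZES = [1, 2, 4, 8, 16, 32]  # assumed; each engine picks its own list


def pick_capture_size(batch_size, capture_sizes=CAPTURE_BATCH_SIZES):
    """Smallest captured batch size that can hold the live batch."""
    i = bisect.bisect_left(capture_sizes, batch_size)
    if i == len(capture_sizes):
        raise ValueError("batch exceeds the largest captured graph")
    return capture_sizes[i]


def block_valid_mask(batch_size, padded_batch_size):
    """True for blocks with real work, False for padding blocks that exit early."""
    return [i < batch_size for i in range(padded_batch_size)]


padded = pick_capture_size(5)
mask = block_valid_mask(5, padded)
print(padded, mask)
```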
4.2 Paged KV-cache attention: how the page table reaches the kernel
The page table uses a CSR (compressed sparse row) layout. Three int32 arrays describe the whole batch:
indices (flat list of physical page ids), indptr ([batch+1], request i owns
indices[indptr[i]:indptr[i+1]]), and last_page_len (fill level of each request’s last page).
The engine (SGLang/vLLM) owns page allocation; FlashInfer only consumes this view. On the C++
side it becomes a 30-line struct passed by value into the kernel:
// include/flashinfer/page.cuh:37
template <typename DType, typename IdType>
struct paged_kv_t {
uint_fastdiv page_size;
uint32_t num_heads, head_dim, batch_size;
uint32_t stride_page, stride_n, stride_h;
DType* k_data; // base pointer of the whole page pool
DType* v_data;
IdType* indices; // physical page ids
IdType* indptr; // [batch_size + 1] per-request extent into indices
IdType* last_page_len; // [batch_size] entries used in the final page
IdType* rope_pos_offset;
...
};
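To make the CSR view concrete, here is a small Python sketch that derives sequence lengths and translates a logical token position into a physical page and slot. The helper names and the example arrays are hypothetical. Only the meaning of indices, indptr, and last_page_len is taken from the description above.

```python
def seq_len(indptr, last_page_len, page_size, i):
    """Number of tokens request i holds: full pages plus the partial last page."""
    num_pages = indptr[i + 1] - indptr[i]
    if num_pages == 0:
        return 0
    return (num_pages - 1) * page_size + last_page_len[i]


def locate(indices, indptr, page_size, i, pos):
    """Map logical token `pos` of request i to (physical page id, slot in page)."""
    page_in_request, slot = divmod(pos, page_size)
    return indices[indptr[i] + page_in_request], slot


page_size = 4
indices = [7, 2, 9, 5]   # physical page ids, all requests back to back
indptr = [0, 3, 4]       # request 0 owns indices[0:3], request 1 owns indices[3:4]
last_page_len = [1, 3]   # tokens used in each request's final page

print(seq_len(indptr, last_page_len, page_size, 0))  # 2 full pages + 1 token
print(locate(indices, indptr, page_size, 0, 5))      # token 5 lives in the second page
```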
Inside the decode kernel, threads translate logical token positions to physical addresses — a
software page-table walk. uint_fastdiv precomputes the magic numbers so the hot loop never does
an integer divide:
// include/flashinfer/attention/decode.cuh:480
uint32_t packed_page_iter_base = paged_kv.indptr[batch_idx] * paged_kv.page_size + chunk_start;
for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
uint32_t q, r;
paged_kv.page_size.divmod(packed_page_iter_base + ((j * bdz + tz) * bdy + ty) * bdx + tx, q, r);
kv_offset_smem[((j * bdz + tz) * bdy + ty) * bdx + tx] =
paged_kv.protective_get_kv_offset(q, kv_head_idx, r, 0, last_indptr);
}
The decode main loop (decode.cuh:523-585) is a textbook bandwidth machine: a multi-stage
cp_async pipeline streams K tiles and V tiles from HBM into shared memory while the previous
tile is being dotted against the single query vector held in registers:
// include/flashinfer/attention/decode.cuh:523 (elided)
for (uint32_t iter = 0; iter < ceil_div(chunk_size, tile_size_per_bdx * bdy * bdz); ++iter) {
cp_async::wait_group<2 * num_stages_smem - 1>(); // wait for oldest K tile
compute_qk<...>(params, variant, batch_idx, k_smem + ..., q_vec, ...); // q . k
cp_async::pred_load<...>(k_smem + ..., paged_kv.k_data + kv_offset[j], ...); // prefetch next K
cp_async::commit_group();
cp_async::wait_group<2 * num_stages_smem - 1>(); // wait for V tile
update_local_state<...>(v_smem + ..., s, stage_idx, st, tx); // softmax-weighted V accumulate
cp_async::pred_load<...>(v_smem + ..., paged_kv.v_data + kv_offset[j], ...); // prefetch next V
cp_async::commit_group();
}
Every KV byte is read exactly once; the arithmetic per byte is a handful of FMAs. That is the
whole “decode is bandwidth-bound” story in code: the kernel’s job is to keep the memory pipes
full, and the planner’s job (4.1) is to make sure there are enough CTAs for that. When a sequence
was split, partial results carry (output, log-sum-exp) and a second kernel merges them — the
same merge operator cascade attention uses (4.4).
4.3 Prefill vs decode kernel families
Two separate kernel templates exist because the two phases sit on opposite ends of the roofline:
- Decode (include/flashinfer/attention/decode.cuh): one query token per request. q·K is a GEMV (O(1) FLOPs per KV byte), so it runs on plain CUDA cores with vectorized loads, as above. Tensor cores would sit idle.
- Prefill (include/flashinfer/attention/prefill.cuh): many query tokens attend to many keys. q·K is a GEMM (O(qo_len) FLOPs per KV byte), so the kernel is built around 16x16x16 tensor-core fragments (mma = matrix-multiply-accumulate instruction):
// include/flashinfer/attention/prefill.cuh:946
for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
if constexpr (std::is_same_v<typename KTraits::DTypeQKAccum, float>) {
if (mma_d == 0) {
mma::mma_sync_m16n16k16_row_col_f16f16f32<typename KTraits::DTypeQ, MMAMode::kInit>(
s_frag[mma_q][mma_kv], a_frag[mma_q], b_frag);
} else {
mma::mma_sync_m16n16k16_row_col_f16f16f32<typename KTraits::DTypeQ>(
s_frag[mma_q][mma_kv], a_frag[mma_q], b_frag);
}
}
...
}
The GQA twist that blurs the line: with grouped-query attention, 4-8 query heads share each KV
head, so even a “decode” step has a small matrix (group_size x head_dim) per KV head. Pass
use_tensor_cores=True to the decode wrapper and it silently routes through the prefill
kernel family (flashinfer/decode.py:1181-1213 calls get_batch_prefill_module) — same page
table, different execution engine. Mixed chunked-prefill+decode batches get dedicated paths:
POD-attention (csrc/batch_pod.cu) fuses both phases into one kernel, and the newer holistic
persistent scheduler lives in csrc/batch_attention.cu /
include/flashinfer/attention/persistent.cuh.
For your mental model: a serving step’s cost = one compute-bound prefill kernel over the new tokens + one bandwidth-bound decode kernel over all resident KV. Chunked prefill exists upstream precisely to interleave these two regimes without starving either resource.
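The two regimes can be quantified with a rough arithmetic-intensity estimate. The sketch below is back-of-envelope math, not code from the repo: it counts only the q·K and P·V multiply-adds, ignores softmax and causal masking, and uses a hypothetical function name.

```python
def attention_flops_per_kv_byte(qo_len, group_size=1, kv_dtype_bytes=2):
    """Approximate FLOPs per KV byte for one KV head.

    Per KV token, q.K and P.V each cost 2 * head_dim FLOPs per query row, and the
    token's K and V together occupy 2 * head_dim * kv_dtype_bytes bytes, so
    head_dim cancels out of the ratio.
    """
    query_rows = qo_len * group_size
    return 2.0 * query_rows / kv_dtype_bytes


print(attention_flops_per_kv_byte(qo_len=1))                # MHA decode, fp16 KV
print(attention_flops_per_kv_byte(qo_len=1, group_size=8))  # GQA decode, 8 heads per KV head
print(attention_flops_per_kv_byte(qo_len=512))              # prefill chunk of 512 tokens
```

At about one FLOP per byte, plain decode sits far below the ratio of compute to bandwidth on any modern GPU. Raising the group size moves it toward the matrix-shaped regime, which is the motivation for the use_tensor_cores path.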
4.4 Cascade attention: the kernel payoff of prefix sharing
FlashInfer generalizes attention into an attention state (v(I), s(I)) — the weighted value
and log-sum-exp over an index set I — with an associative, commutative merge operator
(docs/tutorials/recursive_attention.rst). Attention over I ∪ J equals merging the states of
I and J, computed in any order, on any device. The merge kernel is just a numerically-stable
softmax renormalization:
// include/flashinfer/attention/cascade.cuh:98 (MergeStateInPlaceKernel)
float s_val = s[pos * num_heads + head_idx];
float s_other_val = s_other[pos * num_heads + head_idx];
float s_max = max(s_val, s_other_val);
s_val = math::ptx_exp2(s_val - s_max);
s_other_val = math::ptx_exp2(s_other_val - s_max);
float scale = s_val / (s_val + s_other_val);
float other_scale = s_other_val / (s_val + s_other_val);
for (uint32_t i = 0; i < vec_size; ++i) {
v_vec[i] = scale * v_vec[i] + other_scale * v_other_vec[i];
}
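The merge identity is easy to check numerically. The following pure-Python sketch computes an attention state over two halves of a key set, merges them, and compares against attention over the full set. It uses natural-log LSE for readability, whereas the kernel above works in base 2. All names here are illustrative.

```python
import math


def attention_state(q, keys, values):
    """Return (v, s): softmax-weighted value and log-sum-exp of q.k over the set."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(x - m) for x in scores]
    total = sum(weights)
    dim = len(values[0])
    v = [sum(w * val[d] for w, val in zip(weights, values)) / total for d in range(dim)]
    return v, m + math.log(total)


def merge(state_a, state_b):
    """Combine two attention states over disjoint index sets."""
    (va, sa), (vb, sb) = state_a, state_b
    m = max(sa, sb)  # subtract the max before exponentiating for stability
    wa, wb = math.exp(sa - m), math.exp(sb - m)
    v = [(wa * x + wb * y) / (wa + wb) for x, y in zip(va, vb)]
    return v, m + math.log(wa + wb)


q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [2.0, 1.0], [-1.0, 3.0]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]

full = attention_state(q, keys, values)
merged = merge(attention_state(q, keys[:2], values[:2]),
               attention_state(q, keys[2:], values[2:]))
print(full, merged)
```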
MultiLevelCascadeAttentionWrapper (flashinfer/cascade.py:226) exploits this for shared
prefixes: level 0 treats the entire batch as one query block attending to the shared prefix’s
KV; later levels handle each request’s unique suffix; then the per-level states are merged:
# flashinfer/cascade.py:547
out, lse = self._batch_prefill_wrappers[-1].run(q, paged_kv_cache, return_lse=True)
for wrapper in self._batch_prefill_wrappers[:-1]:
out_i, lse_i = wrapper.run(q, paged_kv_cache, return_lse=True)
merge_state_in_place(out, lse, out_i, lse_i)
return out
Why it’s fast: in the level-0 pass, all batch queries attend to the prefix together (it’s matrix-shaped, tensor-core friendly), and the prefix KV is fetched from HBM once for the whole batch instead of once per request. SGLang’s radix tree shares prefix pages (capacity win); cascade shares prefix traffic (bandwidth win). The same split-then-merge machinery also powers split-KV decode (4.2) and, conceptually, any cross-device attention sharding.
4.5 JIT compilation: kernels generated and cached at runtime
The kernel space is combinatorial — dtype(q) x dtype(kv) x dtype(o) x head_dim x position encoding x sliding window x logits soft-cap x backend x GPU arch. AOT-compiling all of it is infeasible, so FlashInfer compiles on first use. Each combination gets a deterministic URI:
# flashinfer/jit/attention/modules.py:78
return (
f"batch_decode_with_kv_cache_dtype_q_{filename_safe_dtype_map[dtype_q]}_"
f"dtype_kv_{filename_safe_dtype_map[dtype_kv]}_"
f"dtype_o_{filename_safe_dtype_map[dtype_o]}_"
f"dtype_idx_{filename_safe_dtype_map[dtype_idx]}_"
f"head_dim_qk_{head_dim_qk}_"
f"head_dim_vo_{head_dim_vo}_"
f"posenc_{pos_encoding_mode}_"
f"use_swa_{use_sliding_window}_"
f"use_logits_cap_{use_logits_soft_cap}"
)
A gen_*_module() function (e.g. gen_batch_decode_module,
flashinfer/jit/attention/modules.py:915) renders a Jinja template
(csrc/batch_decode_customize_config.jinja) into a type-specialized config header, copies the
.cu sources, and returns a JitSpec. Building means: emit build.ninja, run nvcc, load the
.so through TVM-FFI — with an AOT short-circuit:
# flashinfer/jit/core.py:307
def build_and_load(self):
if self.is_aot:
return self.load(self.aot_path)
# Guard both build and load with the same lock to avoid race condition
with FileLock(self.lock_path, thread_local=False):
so_path = self.jit_library_path
verbose = os.environ.get("FLASHINFER_JIT_VERBOSE", "0") == "1"
self.build(verbose, need_lock=False)
result = self.load(so_path)
return result
Caching is two-level: @functools.cache in-process (see get_sampling_module,
flashinfer/sampling.py:66) and on disk at ~/.cache/flashinfer/<version>/<archs>/cached_ops/
(flashinfer/jit/env.py:135-150 — version and arch list are part of the path, so upgrades and
new GPUs invalidate cleanly). Recompilation triggers: a never-seen parameter combination, source
file hash change, compiler flag change, arch list change, version change.
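The two cache levels can be imitated with a toy. Everything below is a stand-in under stated assumptions: a temporary directory plays the role of the on-disk cache, a placeholder file plays the role of the compiled .so, and module_uri, get_module, and BUILD_LOG are hypothetical names rather than FlashInfer APIs.

```python
import functools
import pathlib
import tempfile

CACHE_DIR = pathlib.Path(tempfile.mkdtemp())  # stand-in for the on-disk cache directory
BUILD_LOG = []                                # records every simulated compile


def module_uri(dtype_q, dtype_kv, head_dim):
    """Deterministic name: the same configuration always maps to the same artifact."""
    return f"batch_decode_dtype_q_{dtype_q}_dtype_kv_{dtype_kv}_head_dim_{head_dim}"


@functools.cache  # level 1: in-process cache keyed by the call arguments
def get_module(dtype_q, dtype_kv, head_dim):
    uri = module_uri(dtype_q, dtype_kv, head_dim)
    artifact = CACHE_DIR / uri / f"{uri}.so"
    if not artifact.exists():  # level 2: on-disk cache keyed by the URI
        artifact.parent.mkdir(parents=True)
        artifact.write_text("placeholder for a compiled kernel")  # simulated build
        BUILD_LOG.append(uri)
    return artifact


get_module("f16", "f16", 128)  # first use: simulated compile
get_module("f16", "f16", 128)  # in-process hit, no disk access
get_module.cache_clear()       # simulate a fresh process
get_module("f16", "f16", 128)  # disk hit, still no rebuild
print(len(BUILD_LOG))
```

The toy omits the invalidation triggers listed above (source hash, flags, arch list, version). In the real layout some of those are encoded in the cache path itself.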
Ops note for the layers you run: the first request with a new model shape stalls for
seconds to minutes while nvcc runs. This is why engines warm up before serving traffic, why the
flashinfer-jit-cache and flashinfer-cubin wheels exist (prebuilt .so files and downloaded
trtllm-gen cubins, respectively), and why FLASHINFER_DISABLE_JIT=1 (flashinfer/jit/core.py:290)
is used in locked-down prod images to fail fast instead of compiling.
4.6 Sampling: selection never leaves the GPU
After attention produces logits, token selection (temperature, top-k, top-p, min-p) runs in a single fused kernel per batch — no sort of the 128k-entry vocab, no host round-trip. The trick is rejection sampling with a shrinking pivot window: sample a token from the full distribution, then check in one block-wide reduction whether the top-p constraint would have admitted it; if not, tighten the admissible probability range and resample. One CTA per batch row; the do/while converges in a few rounds:
// include/flashinfer/sampling.cuh:1070 (TopPSamplingFromProbKernel, inside do { ... })
if (aggregate_gt_pivot_0 < top_p) {
// case 1: pivot_0 accepted
break;
}
if (aggregate_gt_pivot_1 < top_p) {
// case 2: pivot_0 rejected, pivot_1 accepted
low = pivot_0;
high = pivot_1;
q = aggregate_gt_pivot_0;
} else {
// case 3: pivot_0 rejected, pivot_1 rejected
low = pivot_1;
q = aggregate_gt_pivot_1;
}
} while (low < high);
(pivot_0 is the sampled token’s probability, pivot_1 a midpoint — two candidate thresholds
evaluated per pass over the vocab, so each round halves the search interval.) Python entry points:
top_p_sampling_from_probs (flashinfer/sampling.py:976), top_k_top_p_sampling_from_probs
(flashinfer/sampling.py:1579), plus chain_speculative_sampling for draft-token verification in
speculative decoding. Bindings live in csrc/sampling.cu / csrc/renorm.cu. The end-to-end
consequence you already know from the serving side: the only thing crossing PCIe per step is a
handful of int32 token ids.
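The rejection idea can be tried on the CPU. The sketch below is a single-pivot simplification of the dual-pivot loop quoted above: it raises the threshold to the rejected token's probability instead of also testing a midpoint. It is an illustration with a hypothetical function name, not the kernel's algorithm, and it assumes 0 < top_p <= 1.

```python
import random


def top_p_sample(probs, top_p, rng):
    """Sorting-free top-p sampling by rejection (single-pivot variant)."""
    low = 0.0  # tokens with prob <= low are known to lie outside the nucleus
    while True:
        # Inverse-transform sample among tokens still above the threshold.
        candidates = [(i, p) for i, p in enumerate(probs) if p > low]
        u = rng.random() * sum(p for _, p in candidates)
        token, pivot = candidates[-1]  # fallback guards against float round-off
        acc = 0.0
        for i, p in candidates:
            acc += p
            if u < acc:
                token, pivot = i, p
                break
        # A token is in the nucleus iff the mass strictly above it is < top_p.
        if sum(p for p in probs if p > pivot) < top_p:
            return token
        low = pivot  # rejected: it and everything less likely are excluded


rng = random.Random(0)
draws = [top_p_sample([0.5, 0.3, 0.15, 0.05], top_p=0.7, rng=rng) for _ in range(1000)]
print(sorted(set(draws)))
```

Each rejection removes at least one token from the candidate set, and the most likely token is always accepted, so the loop terminates. On the GPU the two sums become block-wide reductions over the vocabulary.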
5. Suggested reading path
1. README.md + docs/tutorials/kv_layout.rst: vocabulary (NHD/HND layouts, ragged vs paged, indptr conventions). Then docs/tutorials/recursive_attention.rst (10 minutes, the merge operator).
2. flashinfer/decode.py:712-900: BatchDecodeWithPagedKVCacheWrapper.__init__ (workspace buffers, CUDA-graph buffer pinning). This class is the contract engines program against.
3. flashinfer/decode.py:901-1281 (plan) and :1352-1660 (run): note what is host metadata vs device tensors, and what gets cached on the wrapper.
4. csrc/batch_decode.cu (200 lines, read in full): Python tensors become paged_kv_t + raw pointers; plan info becomes workspace offsets.
5. include/flashinfer/attention/scheduler.cuh:60-130 and :366-492: the planner. This is the most “your world” file in the repo: occupancy estimation, binary-search load balancing, offset allocation.
6. include/flashinfer/page.cuh:37-210: the page table struct and address translation.
7. include/flashinfer/attention/decode.cuh:380-608: one full bandwidth-bound kernel. Skim for the pipeline structure, not the index arithmetic.
8. include/flashinfer/attention/prefill.cuh: skim compute_qk (:878) just to see the mma fragments; then flashinfer/cascade.py:226-560 + include/flashinfer/attention/cascade.cuh.
9. flashinfer/sampling.py + include/flashinfer/sampling.cuh:968-1090.
10. flashinfer/jit/core.py + flashinfer/jit/attention/modules.py:915-960, then look inside ~/.cache/flashinfer/ after running anything.
11. benchmarks/bench_batch_decode.py: 100 lines that tie the whole cost model together.
End-to-end trace to cement it: one wrapper.run(q, kv_cache) decode call —
decode.py:1352 → csrc/batch_decode.cu:83 → BatchDecodeWithPagedKVCacheDispatched
(include/flashinfer/attention/decode.cuh:741) → BatchDecodeWithPagedKVCacheKernel
(decode.cuh:613) → optional VariableLengthMergeStates (cascade.cuh:687).
6. Connections to your other study repos
- sglang: primary consumer. python/sglang/srt/layers/attention/flashinfer_backend.py wraps these exact wrapper classes; its init_forward_metadata is where plan() gets called each scheduler step, and SGLang’s “capture batch sizes” map to the CUDA-graph padding you saw in 4.1. Its sampler calls flashinfer.sampling.*. SGLang’s radix cache decides what pages are shared; FlashInfer decides how cheaply they’re read.
- vllm: optional attention backend (VLLM_ATTENTION_BACKEND=FLASHINFER); on Blackwell it is also the delivery vehicle for trtllm-gen attention cubins. Comparing vLLM’s FlashInfer backend wrapper with SGLang’s is a good exercise: same plan/run contract, different metadata plumbing.
- nano-vllm: uses flash-attn instead of FlashInfer; reading its attention call next to SGLang’s FlashInfer backend shows exactly what the plan/run + page-table contract buys (and costs) versus a single monolithic kernel call.
- dynamo / llm-d / gateway-api-inference-extension: never touch these kernels, but every latency model they embed (TTFT vs ITL tradeoffs, KV-cache-aware routing, saturation detection) is a function of the two regimes in 4.3. When a router predicts “adding one decode request costs ~X ms/step,” X is the marginal KV bytes divided by bandwidth, i.e., section 4.2.
- xgrammar: its apply_token_bitmask_inplace CUDA kernel is a tiny cousin of the sampling kernels here: same “mutate logits on device, never sync to host” principle, applied just before FlashInfer’s sampling kernel draws the token.
7. Hands-on with your RTX 5080 (16 GB)
Platform honesty: WSL2, not native Windows. There are no Windows wheels, CI is Linux, the JIT
pipeline (ninja + nvcc + .so + TVM-FFI load_module) targets Linux shared objects; the stray
sys.platform == "win32" branch in flashinfer/autotuner.py:90 is not a support statement. Use
WSL2 Ubuntu + CUDA toolkit 12.8+ (SM 12.0 needs it). Clone inside the WSL filesystem (~/), not
under /mnt/c/ — JIT compilation hammers the filesystem.
Your GPU is supported. RTX 5080 = Blackwell SM 12.0, listed in README.md:73. The fa2
decode/prefill paths, cascade, and sampling all work; Hopper-only fa3 (sm90a) and B200-class MLA
(sm100a) kernels won’t be available — flashinfer.utils.is_sm120a_supported is the gate tests
use. Setup:
# inside WSL2 Ubuntu with CUDA 12.8+ toolkit
git clone https://github.com/flashinfer-ai/flashinfer.git --recursive && cd flashinfer
pip install --no-build-isolation -e . -v
export FLASHINFER_CUDA_ARCH_LIST="12.0f" # skip detecting/compiling other archs
flashinfer show-config
Microbenchmark 1 — one attention call (and one JIT compile). Run the README quickstart
(README.md:123-133): flashinfer.single_decode_with_kv_cache(q, k, v) with a 2048-token KV.
First call stalls while nvcc builds the kernel — watch ~/.cache/flashinfer/ populate
(generated/ sources, cached_ops/*.so). Second call is microseconds. That stall is the JIT
story of 4.5 made visceral.
Microbenchmark 2 — the decode bandwidth sweep. benchmarks/bench_batch_decode.py already
measures the exact thing you study:
# benchmarks/bench_batch_decode.py:70-78
measurements = bench_gpu_time(lambda: wrapper.run(q, kv_data))
ms = np.median(measurements)
io = q.numel() * q.element_size() + kv_data.numel() * kv_data.element_size()
print(f"execution time: {ms}ms")
print(f"memory bandwidth: {io / ms / 1024 / 1024:.2f} GB/s")
Trim the sweep loops at the bottom (:81-95) for 16 GB — the largest configs allocate
batch x seq x 2048 bytes of KV (batch 512 x 16384 tokens ≈ 17 GB, OOM; cap at batch ≤ 128 or
seq ≤ 8192). What you should observe: at batch 1 the achieved GB/s is far below spec (not enough
CTAs — this is what split-KV planning fixes); as batch x seq grows, GB/s climbs and plateaus near
the 5080’s ~960 GB/s GDDR7 ceiling while latency becomes a clean linear function of total KV
bytes. That plateau is the per-token decode cost model your schedulers assume. Bonus run:
switch kv_dtype to torch.float8_e4m3fn and watch step time halve — quantized KV as a
bandwidth lever, with zero scheduler changes upstream.
Sanity math for the plateau: per token per layer, KV bytes = 2 (K and V) x num_kv_heads x head_dim x dtype_size = 2 x 4 x 128 x 2 B = 2 KB in this benchmark’s config. Batch 64 x seq 4096 is 262,144 tokens, so ~537 MB per call, or ~0.56 ms at 960 GB/s. Your measurement should land within ~20% of that.
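The same napkin math as a script, useful for predicting each sweep point and for spotting configurations that will not fit in 16 GB. The defaults mirror the benchmark config described above (4 KV heads, head_dim 128, 2-byte dtype). The function names and the 960 GB/s ceiling are assumptions to adjust for your card.

```python
def kv_bytes(batch, seq_len, num_kv_heads=4, head_dim=128, dtype_bytes=2):
    """Bytes of K plus V for one layer across the whole batch."""
    return batch * seq_len * 2 * num_kv_heads * head_dim * dtype_bytes


def decode_ms(batch, seq_len, bandwidth_gb_s=960.0, **config):
    """Ideal bandwidth-bound time to stream that KV once."""
    return kv_bytes(batch, seq_len, **config) / (bandwidth_gb_s * 1e9) * 1e3


for batch, seq in [(64, 4096), (128, 16384), (512, 16384)]:
    gb = kv_bytes(batch, seq) / 1e9
    print(f"batch={batch:4d} seq={seq:6d} KV={gb:6.2f} GB  ideal={decode_ms(batch, seq):.2f} ms")

# FP8 KV halves the bytes, so the ideal time halves too.
print(decode_ms(64, 4096, dtype_bytes=1))
```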
If you can’t run it locally: the next-best path is (1) the FlashInfer paper, arXiv
2501.01005; (2) the two blog posts — flashinfer.ai/2024/02/02/cascade-inference.html and
flashinfer.ai/2025/03/10/sampling.html — which are the prose versions of 4.4 and 4.6; (3)
docs/tutorials/kv_layout.rst and recursive_attention.rst in-repo; (4) read
benchmarks/bench_batch_decode.py and predict its output from the napkin math above — the
prediction exercise teaches the cost model even without the GPU.