FlashInfer

A study guide to the kernel library that sits at the very bottom of your inference stack. Repo: materials/flashinfer/flashinfer (github.com/flashinfer-ai/flashinfer). All paths below are relative to the repo root; line numbers match this checkout.

1. What it is

FlashInfer is the GPU kernel library and kernel generator that inference engines call for the operations that dominate a serving step: attention (prefill, decode, MLA, cascade/shared-prefix, sparse), sampling (top-k/top-p/min-p on device), plus GEMM/MoE, RoPE, and norms. It is not an engine — it has no scheduler, no server, no model code. It exposes PyTorch-callable ops that are JIT-compiled per configuration (dtype, head_dim, feature flags) and cached as .so files.

Who consumes it:

  • SGLang — its standard attention backend on NVIDIA GPUs (--attention-backend flashinfer) and its on-GPU sampling kernels.
  • vLLM — optional attention backend (VLLM_ATTENTION_BACKEND=FLASHINFER), and the channel through which NVIDIA’s trtllm-gen Blackwell kernels reach vLLM.
  • Also TensorRT-LLM, TGI, MLC-LLM, LightLLM (see README.md:231-242).

The project’s own CLAUDE.md is an unusually good architecture document — read it alongside this guide.

2. Why you care

  • It defines the per-step cost model. Every latency assumption upstream (scheduler token budgets, router ITL estimates, autoscaler capacity curves) bottoms out at: decode step time ≈ (weight bytes + batch KV bytes) / HBM bandwidth; prefill time ≈ FLOPs / tensor-core throughput. The hardware sets those ceilings, and the kernels in this repo determine how close a real step gets to them. When SGLang’s scheduler decides how many tokens fit in a step, it is budgeting against these kernels’ behavior.
  • The plan/run split is the reason engines look the way they do. Kernels are “planned” on the CPU (workspace layout, work partitioning) once per batch composition, then “run” per layer with zero host decisions — which is what makes CUDA-graph capture, fixed capture batch sizes, and cheap per-layer dispatch possible. It is a literal control-plane/data-plane separation.
  • Cascade attention is the kernel-level payoff of prefix sharing. Prefix caching as you know it deduplicates KV storage. The cascade kernels additionally deduplicate KV bandwidth — the shared prefix is read from HBM once for the whole batch instead of once per request. That’s why prefix-aware routing pays off twice.
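The first bullet's cost model fits in two functions. This is a minimal sketch. The function names are hypothetical, and the example sizes are made-up round numbers, not measurements. The prefill FLOP count uses the common 2 x parameters x tokens approximation, which ignores attention FLOPs.

```python
def decode_step_seconds(weight_bytes: float, kv_bytes: float, hbm_bytes_per_s: float) -> float:
    """Bandwidth-bound decode: each step streams all weights plus the batch's resident KV once."""
    return (weight_bytes + kv_bytes) / hbm_bytes_per_s


def prefill_seconds(num_params: float, new_tokens: int, tensor_flops_per_s: float) -> float:
    """Compute-bound prefill: ~2 FLOPs per parameter per new token (attention FLOPs ignored)."""
    return 2 * num_params * new_tokens / tensor_flops_per_s


# Hypothetical round numbers: 14 GB of weights, 2 GB of batch KV, 2 TB/s of HBM.
print(f"decode step   ~{decode_step_seconds(14e9, 2e9, 2e12) * 1e3:.1f} ms")
# Hypothetical: 7B parameters, a 2048-token prompt, 400 TFLOP/s sustained.
print(f"prefill chunk ~{prefill_seconds(7e9, 2048, 400e12) * 1e3:.1f} ms")
```

The marginal cost of one more decode request is its KV bytes divided by bandwidth. That is the quantity a KV-aware router or autoscaler is implicitly estimating.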

3. Architecture map

Four layers, with a hard rule: framework-agnostic kernels never see PyTorch headers.

flashinfer/                  Python package: public API, one module per op family
  decode.py, prefill.py        attention wrappers (the plan/run classes)
  cascade.py, sampling.py      shared-prefix attention; on-GPU sampling
  page.py, mla/, pod.py        KV-page append, DeepSeek MLA, fused prefill+decode
  jit/                         JIT machinery: JitSpec, Jinja codegen, ninja builds
    core.py, env.py            build/load/cache; ~/.cache/flashinfer layout
    attention/modules.py       gen_*_module() generators + URI naming
csrc/                        TVM-FFI bindings + launchers (PyTorch tensor world)
  batch_decode.cu              Plan() and Run() entry points for decode
  *.jinja                      type-specialization templates rendered at JIT time
include/flashinfer/          header-only CUDA kernel templates (raw pointers only)
  page.cuh                     paged_kv_t — the page table struct kernels consume
  attention/decode.cuh         CUDA-core decode kernel (bandwidth-bound GEMV)
  attention/prefill.cuh        tensor-core prefill kernel (compute-bound GEMM)
  attention/scheduler.cuh      DecodePlan/PrefillPlan — CPU-side work partitioning
  attention/cascade.cuh        attention-state merge kernels
  sampling.cuh                 sorting-free top-k/top-p rejection sampling
benchmarks/                  per-kernel microbenchmarks (no model weights needed)
flashinfer-cubin/, flashinfer-jit-cache/   optional prebuilt-kernel packages

Bindings use TVM-FFI rather than torch extensions, so the same compiled kernels are usable from non-PyTorch frameworks. 3rdparty/ vendors CUTLASS and spdlog.

4. Core mechanisms

4.1 The plan/run split — control plane vs data plane

Every batch attention API is a wrapper class with two phases. plan() runs on the CPU once per batch composition (which requests, which lengths); run() executes per layer — 32-80 times per step — with no host-side decisions. If you squint, plan is route computation and run is packet forwarding against an installed FIB.
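Here is a toy version of the split in plain Python. The class and method names are hypothetical and are not FlashInfer's API. The point is the shape. plan() reads host-side lengths once and freezes a schedule. run() only consumes that schedule, so it can be called once per layer without making any decisions.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Schedule:
    """A frozen per-batch plan: one (request, kv_start, kv_end) entry per work item."""
    work_items: tuple


class ToyDecodeWrapper:
    """Hypothetical stand-in for a plan/run wrapper."""

    def __init__(self, chunk_tokens: int):
        self.chunk_tokens = chunk_tokens
        self._schedule = None

    def plan(self, kv_lens):
        # Control plane: once per batch composition; free to read lengths on the host.
        items = []
        for req, n in enumerate(kv_lens):
            for start in range(0, n, self.chunk_tokens):
                items.append((req, start, min(start + self.chunk_tokens, n)))
        self._schedule = Schedule(tuple(items))

    def run(self, layer_id):
        # Data plane: once per layer; replays the installed schedule, no shape decisions.
        if self._schedule is None:
            raise RuntimeError("plan() must be called before run()")
        return [(layer_id, *item) for item in self._schedule.work_items]


wrapper = ToyDecodeWrapper(chunk_tokens=4)
wrapper.plan([10, 3])                                     # once per step
per_layer = [wrapper.run(layer) for layer in range(32)]   # 32 runs against one plan
```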

flashinfer/decode.py:901 (BatchDecodeWithPagedKVCacheWrapper.plan) takes the page-table metadata as host tensors, because the planner needs to read lengths on the CPU, and calls into C++ BatchDecodeWithPagedKVCachePlan (csrc/batch_decode.cu:39). The real work is DecodePlan in include/flashinfer/attention/scheduler.cuh:426:

// include/flashinfer/attention/scheduler.cuh:437
FLASHINFER_CUDA_CALL(work_estimation_func(split_kv, max_grid_size, kv_chunk_size_in_pages,
                                          new_batch_size, gdy, batch_size, indptr_h, num_qo_heads,
                                          page_size, enable_cuda_graph, stream));
size_t padded_batch_size;
plan_info.enable_cuda_graph = enable_cuda_graph;
plan_info.split_kv = split_kv;
padded_batch_size =
    (enable_cuda_graph) ? (split_kv ? max_grid_size / gdy : batch_size) : new_batch_size;
plan_info.padded_batch_size = padded_batch_size;
auto [request_indices_vec, kv_tile_indices_vec, o_indptr_vec] =
    DecodeSplitKVIndptr(indptr_h, batch_size, kv_chunk_size_in_pages);

What the plan actually decides:

  • Load balancing. Decode at small batch sizes can’t fill the GPU (one CTA per request per KV head). The planner splits long sequences into KV chunks, binary-searching for the smallest chunk size (in pages) whose total CTA count still fits in max_grid_size, the number of CTAs the GPU can keep resident at once (PartitionPagedKVCacheBinarySearchMinNumPagePerBatch, include/flashinfer/attention/scheduler.cuh:74). Smaller chunks mean more CTAs, so the result is the finest split that still fits in one wave. It is a small partitioning problem, solved ahead of time on the CPU.
  • Workspace layout. All scratch (partial outputs for split-KV, index arrays) is carved out of one user-provided 128 MB workspace buffer by an offset allocator — no per-call cudaMalloc.
  • One H2D copy. The whole schedule is staged in pinned host memory and shipped with a single cudaMemcpyAsync (scheduler.cuh:489).
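Below is the load-balancing search as a Python sketch of the idea. min_pages_per_chunk and ceil_div are hypothetical names. The exact bounds and tie-breaking of the real PartitionPagedKVCacheBinarySearchMinNumPagePerBatch may differ, so check scheduler.cuh:74 before relying on details.

```python
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


def min_pages_per_chunk(num_pages, max_grid_size, gdy, min_pages=1):
    """Smallest KV chunk size (in pages) whose total CTA count fits in max_grid_size.

    num_pages[i]: request i's KV length in pages.
    gdy: CTAs launched per chunk (e.g. one per KV-head group).
    Returns (chunk_pages, new_batch_size), where new_batch_size counts the chunks.
    """
    low, high = min_pages, max(max(num_pages), min_pages)
    while low < high:
        mid = (low + high) // 2
        ctas = sum(ceil_div(n, mid) for n in num_pages) * gdy
        if ctas > max_grid_size:
            low = mid + 1   # more than one wave: chunks must get bigger
        else:
            high = mid      # fits: try a finer split
    new_batch_size = sum(ceil_div(max(n, 1), low) for n in num_pages)
    return low, new_batch_size


# One long request and three short ones on a GPU that keeps 16 CTAs resident:
# the long request is cut into 13 chunks of 8 pages, and 13 + 3 fills the wave.
print(min_pages_per_chunk([100, 1, 1, 1], max_grid_size=16, gdy=1))
```

The search runs toward finer chunks because the "fits" test is monotone: a bigger chunk never produces more CTAs.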

The plan is returned to Python as an opaque vector of int64 offsets (DecodePlanInfo::ToVector, scheduler.cuh:391). At run time, csrc/batch_decode.cu:175 rehydrates raw pointers from it:

// csrc/batch_decode.cu:175
params.request_indices =
    GetPtrFromBaseOffset<IdType>(int_buffer, plan_info.request_indices_offset);
params.kv_tile_indices =
    GetPtrFromBaseOffset<IdType>(int_buffer, plan_info.kv_tile_indices_offset);
params.o_indptr = GetPtrFromBaseOffset<IdType>(int_buffer, plan_info.o_indptr_offset);
if (plan_info.split_kv) {
  tmp_v = GetPtrFromBaseOffset<DTypeO>(float_buffer, plan_info.v_offset);
  tmp_s = GetPtrFromBaseOffset<float>(float_buffer, plan_info.s_offset);
}
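The "one workspace, offsets only" scheme combines the workspace-layout bullet with the rehydration code above. Here it is as a Python sketch. OffsetAllocator and view_int32 are hypothetical stand-ins for the C++ allocator and GetPtrFromBaseOffset. The 16-byte alignment and int32 index arrays are assumptions.

```python
class OffsetAllocator:
    """Bump allocator handing out aligned byte offsets inside one preallocated workspace."""

    def __init__(self, capacity_bytes: int):
        self.capacity = capacity_bytes
        self.cursor = 0

    def aligned_alloc_offset(self, size_bytes: int, align: int, name: str) -> int:
        start = (self.cursor + align - 1) // align * align   # round up to the alignment
        if start + size_bytes > self.capacity:
            raise MemoryError(f"workspace too small for {name}")
        self.cursor = start + size_bytes
        return start


def view_int32(buf: bytearray, offset: int, count: int) -> memoryview:
    """Analogue of GetPtrFromBaseOffset<int32>: reinterpret the bytes at base + offset."""
    return memoryview(buf)[offset:offset + 4 * count].cast("i")


# plan(): carve regions and keep only integer offsets (cf. DecodePlanInfo::ToVector).
workspace = bytearray(1 << 16)          # stand-in for the user-provided int workspace
alloc = OffsetAllocator(len(workspace))
n_items = 5
plan_info = [
    alloc.aligned_alloc_offset(4 * n_items, 16, "request_indices"),
    alloc.aligned_alloc_offset(4 * n_items, 16, "kv_tile_indices"),
    alloc.aligned_alloc_offset(4 * (n_items + 1), 16, "o_indptr"),
]
request_indices = view_int32(workspace, plan_info[0], n_items)
for i in range(n_items):
    request_indices[i] = i * 10         # the planner writes the schedule once

# run(): rebuild the view from the offset alone; the data is already in place.
rehydrated = list(view_int32(workspace, plan_info[0], n_items))
```

Because the plan is only integers, it crosses the Python boundary cheaply. The run side never allocates: it re-derives every pointer from the same base buffer.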

CUDA-graph tie-in: plan() explicitly “cannot be used in CUDA Graph or torch.compile” (flashinfer/decode.py:1006), while run() is capturable. Under graphs the grid size must be fixed, so the planner pads the launch to a fixed padded_batch_size and the kernel early-exits surplus blocks via block_valid_mask (include/flashinfer/attention/decode.cuh:423). This is why engines capture graphs at a fixed set of batch sizes and bucket requests into them.
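Bucketing and padding can be reduced to a few lines. This is a hypothetical sketch. The capture sizes are just an example set, and one block per request stands in for the planner's per-chunk work items. pick_graph and make_valid_mask are illustrative names; the mask mirrors the kernel's block_valid_mask only in spirit.

```python
import bisect

CAPTURE_SIZES = [1, 2, 4, 8, 16, 32, 64, 128]   # hypothetical sizes captured at startup


def pick_graph(batch_size: int) -> int:
    """Smallest captured batch size that can hold this batch (bucket up)."""
    i = bisect.bisect_left(CAPTURE_SIZES, batch_size)
    if i == len(CAPTURE_SIZES):
        raise ValueError("batch exceeds every captured graph; fall back to eager launch")
    return CAPTURE_SIZES[i]


def make_valid_mask(real_items: int, padded_items: int) -> list:
    """The launch is always padded_items wide; entries past real_items exit early."""
    return [i < real_items for i in range(padded_items)]


padded = pick_graph(5)                  # replay the graph captured at batch size 8
mask = make_valid_mask(5, padded)       # the 3 padded blocks do nothing
```

The price is idle launch slots (3 of 8 here). The gain is that the captured graph replays unchanged.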

4.2 Paged KV-cache attention: how the page table reaches the kernel

The page table is CSR. Three int32 arrays describe the whole batch: indices (flat list of physical page ids), indptr ([batch+1], request i owns indices[indptr[i]:indptr[i+1]]), and last_page_len (fill level of each request’s last page). The engine (SGLang/vLLM) owns page allocation; FlashInfer only consumes this view. On the C++ side it becomes a 30-line struct passed by value into the kernel:

// include/flashinfer/page.cuh:37
template <typename DType, typename IdType>
struct paged_kv_t {
  uint_fastdiv page_size;
  uint32_t num_heads, head_dim, batch_size;
  uint32_t stride_page, stride_n, stride_h;
  DType* k_data;   // base pointer of the whole page pool
  DType* v_data;
  IdType* indices;        // physical page ids
  IdType* indptr;         // [batch_size + 1] per-request extent into indices
  IdType* last_page_len;  // [batch_size] entries used in the final page
  IdType* rope_pos_offset;
  ...
};
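Building the three CSR arrays from what an engine's allocator tracks takes only a few lines. build_paged_kv_csr is a hypothetical helper. Engines build these arrays in their own metadata code; FlashInfer only reads them.

```python
def build_paged_kv_csr(page_lists, seq_lens, page_size):
    """Flatten per-request page lists into the (indptr, indices, last_page_len) view.

    page_lists[i]: physical page ids allocated to request i, in logical order.
    seq_lens[i]: number of KV entries request i has cached (assumed >= 1).
    """
    indptr, indices, last_page_len = [0], [], []
    for pages, n in zip(page_lists, seq_lens):
        if n < 1 or len(pages) != (n + page_size - 1) // page_size:
            raise ValueError("each request needs ceil(seq_len / page_size) >= 1 pages")
        indices.extend(pages)
        indptr.append(len(indices))
        last_page_len.append(n - (len(pages) - 1) * page_size)   # fill level of the last page
    return indptr, indices, last_page_len


# page_size 16: request 0 holds 40 tokens on pages 7, 2, 9; request 1 holds 16 tokens on page 4.
indptr, indices, last_page_len = build_paged_kv_csr([[7, 2, 9], [4]], [40, 16], 16)
```

To recover a request's length, reverse the computation: (indptr[i+1] - indptr[i] - 1) * page_size + last_page_len[i]. The physical ids need not be contiguous or ordered. That indirection is what lets the engine place pages anywhere in the pool.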

Inside the decode kernel, threads translate logical token positions to physical addresses — a software page-table walk. uint_fastdiv precomputes the magic numbers so the hot loop never does an integer divide:

// include/flashinfer/attention/decode.cuh:480
uint32_t packed_page_iter_base = paged_kv.indptr[batch_idx] * paged_kv.page_size + chunk_start;
for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
  uint32_t q, r;
  paged_kv.page_size.divmod(packed_page_iter_base + ((j * bdz + tz) * bdy + ty) * bdx + tx, q, r);
  kv_offset_smem[((j * bdz + tz) * bdy + ty) * bdx + tx] =
      paged_kv.protective_get_kv_offset(q, kv_head_idx, r, 0, last_indptr);
}
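The snippet above combines two ideas: the page-table walk and division by a precomputed magic number. Here are both in Python, under stated assumptions. FastDiv uses the textbook round-up multiply-and-shift method (Granlund and Montgomery), not necessarily the exact constants uint_fastdiv computes. kv_offset and its NHD-style strides are illustrative: stride_page = page_size x num_heads x head_dim, stride_n = num_heads x head_dim, stride_h = head_dim. The bounds checking done by protective_get_kv_offset is omitted.

```python
class FastDiv:
    """Division by a fixed 32-bit divisor with one multiply and one shift."""

    def __init__(self, d: int):
        if not 1 <= d < 2**32:
            raise ValueError("divisor must fit in 32 bits")
        self.d = d
        self.shift = 32 + (d - 1).bit_length()     # 32 + ceil(log2(d))
        self.magic = -(-(1 << self.shift) // d)     # ceil(2**shift / d), computed once

    def divmod(self, n: int):
        q = (n * self.magic) >> self.shift         # exact for 0 <= n < 2**32
        return q, n - q * self.d


def kv_offset(indptr, indices, page_size: FastDiv, stride_page, stride_n, stride_h,
              batch_idx, token_pos, kv_head):
    """Element offset of (request, logical token, KV head) inside the shared page pool."""
    packed = indptr[batch_idx] * page_size.d + token_pos   # position in packed index space
    q, r = page_size.divmod(packed)                        # q indexes `indices`; r = slot in page
    return indices[q] * stride_page + r * stride_n + kv_head * stride_h


# NHD-style pool with page_size 16, 4 KV heads, head_dim 128 (illustrative numbers).
ps, heads, dim = FastDiv(16), 4, 128
strides = (16 * heads * dim, heads * dim, dim)
# Request 0, token 35 lives in its third page (physical page 9), slot 3.
off = kv_offset([0, 3, 4], [7, 2, 9, 4], ps, *strides, batch_idx=0, token_pos=35, kv_head=1)
```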

The decode main loop (decode.cuh:523-585) is a textbook bandwidth machine: a multi-stage cp_async pipeline streams K tiles and V tiles from HBM into shared memory while the previous tile is being dotted against the single query vector held in registers:

// include/flashinfer/attention/decode.cuh:523 (elided)
for (uint32_t iter = 0; iter < ceil_div(chunk_size, tile_size_per_bdx * bdy * bdz); ++iter) {
  cp_async::wait_group<2 * num_stages_smem - 1>();   // wait for oldest K tile
  compute_qk<...>(params, variant, batch_idx, k_smem + ..., q_vec, ...);  // q . k
  cp_async::pred_load<...>(k_smem + ..., paged_kv.k_data + kv_offset[j], ...);  // prefetch next K
  cp_async::commit_group();
  cp_async::wait_group<2 * num_stages_smem - 1>();   // wait for V tile
  update_local_state<...>(v_smem + ..., s, stage_idx, st, tx);  // softmax-weighted V accumulate
  cp_async::pred_load<...>(v_smem + ..., paged_kv.v_data + kv_offset[j], ...);  // prefetch next V
  cp_async::commit_group();
}

Every KV byte is read exactly once; the arithmetic per byte is a handful of FMAs. That is the whole “decode is bandwidth-bound” story in code: the kernel’s job is to keep the memory pipes full, and the planner’s job (4.1) is to make sure there are enough CTAs for that. When a sequence was split, partial results carry (output, log-sum-exp) and a second kernel merges them — the same merge operator cascade attention uses (4.4).
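In essence, the per-tile V update is an online softmax. It keeps a running max, a running denominator, and an unnormalized accumulator, and rescales all three whenever the max grows. Below is a plain-Python sketch of one query against one KV head, checked against a direct softmax. The tile size, the natural-log units (the kernel works in base 2), and the function names are assumptions.

```python
import math
import random


def decode_one_query(q, keys, values, tile=4):
    """One query vs. streamed K/V tiles; each K and V row is read exactly once."""
    scale = 1.0 / math.sqrt(len(q))
    m, d = -math.inf, 0.0                    # running max and running denominator
    acc = [0.0] * len(values[0])             # unnormalized output
    for start in range(0, len(keys), tile):
        k_tile, v_tile = keys[start:start + tile], values[start:start + tile]
        s = [scale * sum(a * b for a, b in zip(q, k)) for k in k_tile]   # q . k
        m_new = max(m, max(s))
        rescale = math.exp(m - m_new)        # shrink earlier terms if the max grew
        p = [math.exp(x - m_new) for x in s]
        d = d * rescale + sum(p)
        acc = [a * rescale + sum(pj * v[c] for pj, v in zip(p, v_tile))
               for c, a in enumerate(acc)]
        m = m_new
    return [a / d for a in acc], m + math.log(d)   # output and its log-sum-exp


def reference(q, keys, values):
    scale = 1.0 / math.sqrt(len(q))
    s = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    mx = max(s)
    w = [math.exp(x - mx) for x in s]
    z = sum(w)
    out = [sum(wi * v[c] for wi, v in zip(w, values)) / z for c in range(len(values[0]))]
    return out, mx + math.log(z)


rng = random.Random(0)
q = [rng.gauss(0, 1) for _ in range(8)]
keys = [[rng.gauss(0, 1) for _ in range(8)] for _ in range(37)]
values = [[rng.gauss(0, 1) for _ in range(8)] for _ in range(37)]
out, lse = decode_one_query(q, keys, values)
ref_out, ref_lse = reference(q, keys, values)
```

The returned log-sum-exp is what makes split-KV possible. Run this per chunk, keep one (output, lse) pair per chunk, and merge the pairs afterwards.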

4.3 Prefill vs decode kernel families

Two separate kernel templates exist because the two phases sit on opposite ends of the roofline:

  • Decode (include/flashinfer/attention/decode.cuh): one query token per request. q·K is a GEMV — O(1) FLOPs per KV byte — so it runs on plain CUDA cores with vectorized loads, as above. Tensor cores would sit idle.
  • Prefill (include/flashinfer/attention/prefill.cuh): many query tokens attend to many keys. q·K is a GEMM — O(qo_len) FLOPs per KV byte — so the kernel is built around 16x16x16 tensor-core fragments (mma = matrix-multiply-accumulate instruction):
// include/flashinfer/attention/prefill.cuh:946
for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
  if constexpr (std::is_same_v<typename KTraits::DTypeQKAccum, float>) {
    if (mma_d == 0) {
      mma::mma_sync_m16n16k16_row_col_f16f16f32<typename KTraits::DTypeQ, MMAMode::kInit>(
          s_frag[mma_q][mma_kv], a_frag[mma_q], b_frag);
    } else {
      mma::mma_sync_m16n16k16_row_col_f16f16f32<typename KTraits::DTypeQ>(
          s_frag[mma_q][mma_kv], a_frag[mma_q], b_frag);
    }
  }
  ...
}

The GQA twist that blurs the line: with grouped-query attention, 4-8 query heads share each KV head, so even a “decode” step has a small matrix (group_size x head_dim) per KV head. Pass use_tensor_cores=True to the decode wrapper and it silently routes through the prefill kernel family (flashinfer/decode.py:1181-1213 calls get_batch_prefill_module) — same page table, different execution engine. Mixed chunked-prefill+decode batches get dedicated paths: POD-attention (csrc/batch_pod.cu) fuses both phases into one kernel, and the newer holistic persistent scheduler lives in csrc/batch_attention.cu / include/flashinfer/attention/persistent.cuh.

For your mental model: a serving step’s cost = one compute-bound prefill kernel over the new tokens + one bandwidth-bound decode kernel over all resident KV. Chunked prefill exists upstream precisely to interleave these two regimes without starving either resource.
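The roofline split above comes down to one line of arithmetic. This sketch counts attention FLOPs per byte of K and V read, for one KV head. It assumes 2 FLOPs per multiply-add and ignores softmax and causal masking. The ridge-point hardware numbers are hypothetical.

```python
def flops_per_kv_byte(qo_len: int, group_size: int = 1, kv_dtype_bytes: int = 2) -> float:
    """Attention FLOPs per byte of K+V read, for one KV head.

    For every KV token, each of the qo_len * group_size query rows does a q.k dot product
    and a p.v accumulate: 2 * head_dim multiply-adds = 4 * head_dim FLOPs.
    Reading that token's K and V costs 2 * head_dim * kv_dtype_bytes bytes, so head_dim cancels.
    """
    query_rows = qo_len * group_size
    return (4 * query_rows) / (2 * kv_dtype_bytes)


# Hypothetical GPU: 400 TFLOP/s of tensor-core math over 2 TB/s of bandwidth.
ridge = 400e12 / 2e12   # FLOPs per byte needed before compute, not memory, becomes the limit
cases = [("decode, MHA", 1, 1), ("decode, GQA x8", 1, 8), ("prefill, 2048 tokens", 2048, 1)]
for name, qo_len, group in cases:
    intensity = flops_per_kv_byte(qo_len, group)
    regime = "compute-bound" if intensity >= ridge else "bandwidth-bound"
    print(f"{name:22s} {intensity:7.1f} FLOPs/B  {regime}")
```

Even with GQA x8, decode stays well below this ridge point. Tensor cores make the small per-head matrix cheaper to compute, but the step remains bandwidth-bound.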

4.4 Cascade attention: the kernel payoff of prefix sharing

FlashInfer generalizes attention into an attention state (v(I), s(I)) — the weighted value and log-sum-exp over an index set I — with an associative, commutative merge operator (docs/tutorials/recursive_attention.rst). Attention over I ∪ J equals merging the states of I and J, computed in any order, on any device. The merge kernel is just a numerically-stable softmax renormalization:

// include/flashinfer/attention/cascade.cuh:98 (MergeStateInPlaceKernel)
float s_val = s[pos * num_heads + head_idx];
float s_other_val = s_other[pos * num_heads + head_idx];
float s_max = max(s_val, s_other_val);
s_val = math::ptx_exp2(s_val - s_max);
s_other_val = math::ptx_exp2(s_other_val - s_max);
float scale = s_val / (s_val + s_other_val);
float other_scale = s_other_val / (s_val + s_other_val);
for (uint32_t i = 0; i < vec_size; ++i) {
  v_vec[i] = scale * v_vec[i] + other_scale * v_other_vec[i];
}
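The merge operator and its associativity can be checked numerically. This sketch assumes each state holds a normalized output v and a base-2 log-sum-exp s, matching the ptx_exp2 calls above. merge_state and state_of are hypothetical names. The merged s is derived from the definition; the kernel excerpt does not show its s update.

```python
import math
import random


def state_of(scores, values):
    """Attention state (v, s) of one index set; scores are logits in base-2 units."""
    mx = max(scores)
    w = [2.0 ** (x - mx) for x in scores]
    z = sum(w)
    v = [sum(wi * val[c] for wi, val in zip(w, values)) / z for c in range(len(values[0]))]
    return v, mx + math.log2(z)


def merge_state(a, b):
    """Merge two states: the kernel excerpt's arithmetic plus the merged log-sum-exp."""
    (v_a, s_a), (v_b, s_b) = a, b
    s_max = max(s_a, s_b)
    w_a, w_b = 2.0 ** (s_a - s_max), 2.0 ** (s_b - s_max)
    v = [(w_a * x + w_b * y) / (w_a + w_b) for x, y in zip(v_a, v_b)]
    return v, s_max + math.log2(w_a + w_b)


rng = random.Random(1)
scores = [rng.uniform(-4, 4) for _ in range(12)]
values = [[rng.gauss(0, 1) for _ in range(3)] for _ in range(12)]
parts = [state_of(scores[i:i + 4], values[i:i + 4]) for i in (0, 4, 8)]   # I, J, K
full = state_of(scores, values)
left = merge_state(merge_state(parts[0], parts[1]), parts[2])    # (I + J) + K
right = merge_state(parts[2], merge_state(parts[1], parts[0]))   # K + (J + I)
```

Both groupings reproduce attention over the full set, which is why the pieces can be computed in any order, or on any device.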

MultiLevelCascadeAttentionWrapper (flashinfer/cascade.py:226) exploits this for shared prefixes. Level 0 treats the entire batch as one query block attending to the shared prefix’s KV. The following levels handle KV shared by fewer requests, down to each request’s unique suffix. Finally, the per-level states are merged:

# flashinfer/cascade.py:547
out, lse = self._batch_prefill_wrappers[-1].run(q, paged_kv_cache, return_lse=True)
for wrapper in self._batch_prefill_wrappers[:-1]:
    out_i, lse_i = wrapper.run(q, paged_kv_cache, return_lse=True)
    merge_state_in_place(out, lse, out_i, lse_i)
return out

Why it’s fast: in the level-0 pass, all batch queries attend the prefix together (it’s matrix-shaped, tensor-core friendly), and the prefix KV is fetched from HBM once for the whole batch instead of once per request. SGLang’s radix tree shares prefix pages (capacity win); cascade shares prefix traffic (bandwidth win). Same split-then-merge machinery also powers split-KV decode (4.2) and, conceptually, any cross-device attention sharding.
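The bandwidth win is easy to put a number on. This back-of-envelope sketch assumes one shared prefix and no L2 reuse across requests in the non-cascade case. It also ignores the merge pass. The function name and every size are hypothetical.

```python
def kv_bytes_read(batch, prefix_tokens, suffix_tokens, bytes_per_token, cascade):
    """Idealized HBM bytes of KV read by one attention layer when a batch shares one prefix."""
    suffix = batch * suffix_tokens * bytes_per_token          # unique tails: always per request
    if cascade:
        prefix = prefix_tokens * bytes_per_token              # level 0: once for the whole batch
    else:
        prefix = batch * prefix_tokens * bytes_per_token      # every request re-reads shared pages
    return prefix + suffix


# Hypothetical: 64 requests share a 4096-token system prompt, each adds 256 tokens, 2 KB/token.
flat = kv_bytes_read(64, 4096, 256, 2048, cascade=False)
casc = kv_bytes_read(64, 4096, 256, 2048, cascade=True)
print(f"{flat / 1e6:.0f} MB vs {casc / 1e6:.0f} MB per layer ({flat / casc:.1f}x less traffic)")
```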

4.5 JIT compilation: kernels generated and cached at runtime

The kernel space is combinatorial — dtype(q) x dtype(kv) x dtype(o) x head_dim x position encoding x sliding window x logits soft-cap x backend x GPU arch. AOT-compiling all of it is infeasible, so FlashInfer compiles on first use. Each combination gets a deterministic URI:

# flashinfer/jit/attention/modules.py:78
return (
    f"batch_decode_with_kv_cache_dtype_q_{filename_safe_dtype_map[dtype_q]}_"
    f"dtype_kv_{filename_safe_dtype_map[dtype_kv]}_"
    f"dtype_o_{filename_safe_dtype_map[dtype_o]}_"
    f"dtype_idx_{filename_safe_dtype_map[dtype_idx]}_"
    f"head_dim_qk_{head_dim_qk}_"
    f"head_dim_vo_{head_dim_vo}_"
    f"posenc_{pos_encoding_mode}_"
    f"use_swa_{use_sliding_window}_"
    f"use_logits_cap_{use_logits_soft_cap}"
)

A gen_*_module() function (e.g. gen_batch_decode_module, flashinfer/jit/attention/modules.py:915) renders a Jinja template (csrc/batch_decode_customize_config.jinja) into a type-specialized config header, copies the .cu sources, and returns a JitSpec. Building means: emit build.ninja, run nvcc, load the .so through TVM-FFI — with an AOT short-circuit:

# flashinfer/jit/core.py:307
def build_and_load(self):
    if self.is_aot:
        return self.load(self.aot_path)
    # Guard both build and load with the same lock to avoid race condition
    with FileLock(self.lock_path, thread_local=False):
        so_path = self.jit_library_path
        verbose = os.environ.get("FLASHINFER_JIT_VERBOSE", "0") == "1"
        self.build(verbose, need_lock=False)
        result = self.load(so_path)
    return result

Caching is two-level: @functools.cache in-process (see get_sampling_module, flashinfer/sampling.py:66) and on disk at ~/.cache/flashinfer/<version>/<archs>/cached_ops/ (flashinfer/jit/env.py:135-150 — version and arch list are part of the path, so upgrades and new GPUs invalidate cleanly). Recompilation triggers: a never-seen parameter combination, source file hash change, compiler flag change, arch list change, version change.
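Here is a toy model of the two cache levels. It shows why a restarted process skips compilation, while a version or arch change forces a rebuild. Everything here is hypothetical. A temp directory stands in for ~/.cache/flashinfer, and fake_nvcc writes a placeholder file instead of building a .so. The per-URI file layout under cached_ops is assumed. The real build also reacts to source hashes and compiler flags, which this sketch ignores.

```python
import functools
import os
import tempfile

CACHE_ROOT = tempfile.mkdtemp()            # stand-in for ~/.cache/flashinfer
VERSION, ARCHS = "demo-version", "120f"    # hypothetical version and arch-list path components
compile_count = 0


def disk_path(uri: str) -> str:
    # Version and arch list are directories, so changing either simply misses the cache.
    return os.path.join(CACHE_ROOT, VERSION, ARCHS, "cached_ops", uri, uri + ".so")


def fake_nvcc(path: str) -> None:
    global compile_count
    compile_count += 1                     # count "compiles" so the caching is observable
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w") as f:
        f.write("placeholder for a compiled module\n")


@functools.cache                           # level 1: in-process memo, one entry per URI
def get_module(uri: str) -> str:
    path = disk_path(uri)
    if not os.path.exists(path):           # level 2: on-disk artifact survives restarts
        fake_nvcc(path)
    return path


get_module("batch_decode_demo")            # compiles
get_module("batch_decode_demo")            # in-process hit
get_module.cache_clear()                   # simulate a fresh process
get_module("batch_decode_demo")            # disk hit, no compile
```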

Ops note for the layers you run: the first request with a new model shape stalls for seconds-to-minutes of nvcc. This is why engines warm up before serving traffic, why flashinfer-jit-cache and flashinfer-cubin wheels exist (prebuilt .sos / downloaded trtllm-gen cubins), and why FLASHINFER_DISABLE_JIT=1 (flashinfer/jit/core.py:290) is used in locked-down prod images to fail fast instead of compiling.

4.6 Sampling: selection never leaves the GPU

After attention produces logits, token selection (temperature, top-k, top-p, min-p) runs in a single fused kernel per batch — no sort of the 128k-entry vocab, no host round-trip. The trick is rejection sampling with a shrinking pivot window: sample a token from the full distribution, then check in one block-wide reduction whether the top-p constraint would have admitted it; if not, tighten the admissible probability range and resample. One CTA per batch row; the do/while converges in a few rounds:

// include/flashinfer/sampling.cuh:1070 (TopPSamplingFromProbKernel, inside do { ... })
if (aggregate_gt_pivot_0 < top_p) {
  // case 1: pivot_0 accepted
  break;
}
if (aggregate_gt_pivot_1 < top_p) {
  // case 2: pivot_0 rejected, pivot_1 accepted
  low = pivot_0;
  high = pivot_1;
  q = aggregate_gt_pivot_0;
} else {
  // case 3: pivot_0 rejected, pivot_1 rejected
  low = pivot_1;
  q = aggregate_gt_pivot_1;
}
} while (low < high);

(pivot_0 is the sampled token’s probability, and pivot_1 is the midpoint between pivot_0 and high. Both candidate thresholds are evaluated in one pass over the vocab, so every rejected round at least halves the [low, high] interval.) Python entry points: top_p_sampling_from_probs (flashinfer/sampling.py:976), top_k_top_p_sampling_from_probs (flashinfer/sampling.py:1579), plus chain_speculative_sampling for draft-token verification in speculative decoding. Bindings live in csrc/sampling.cu / csrc/renorm.cu. The end-to-end consequence you already know from the serving side: the only thing crossing PCIe per step is a handful of int32 token ids.
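This is a simplified CPU sketch of the dual-pivot idea for top-p. A token is inside the nucleus exactly when the probability mass strictly above it is below top_p. Some details are assumptions: the function name, the initial high = max(probs), and the sequential inverse-CDF draw. The kernel does each pass as block-wide reductions and carries the remaining mass q instead of recomputing it.

```python
import random


def top_p_sample(probs, top_p, rng):
    """Top-p sampling by dual-pivot rejection, with no sort (simplified CPU sketch)."""
    low, high = 0.0, max(probs)
    while True:
        # Draw a token among those above `low`, proportional to its probability.
        mass = sum(p for p in probs if p > low)
        u = rng.random() * mass
        token = None
        for i, p in enumerate(probs):
            if p > low:
                token = i
                u -= p
                if u < 0:
                    break
        pivot_0 = probs[token]
        pivot_1 = (pivot_0 + high) / 2
        above_0 = sum(p for p in probs if p > pivot_0)    # one reduction per pivot
        above_1 = sum(p for p in probs if p > pivot_1)
        if above_0 < top_p:
            return token                    # case 1: token is inside the nucleus
        if above_1 < top_p:
            low, high = pivot_0, pivot_1    # case 2: threshold lies between the pivots
        else:
            low = pivot_1                   # case 3: nothing at or below pivot_1 can qualify


probs = [0.5, 0.2, 0.15, 0.1, 0.05]
rng = random.Random(0)
draws = [top_p_sample(probs, 0.6, rng) for _ in range(2000)]   # nucleus is {0, 1}
```

Raising low only ever excludes tokens that would be rejected, so the accepted token is distributed in proportion to its probability over the nucleus. That is the same distribution sort-and-truncate gives.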

5. Suggested reading path

  1. README.md + docs/tutorials/kv_layout.rst — vocabulary: NHD/HND layouts, ragged vs paged, indptr conventions. Then docs/tutorials/recursive_attention.rst (10 minutes, the merge operator).
  2. flashinfer/decode.py:712-900 (BatchDecodeWithPagedKVCacheWrapper.__init__): workspace buffers, CUDA-graph buffer pinning. This class is the contract engines program against.
  3. flashinfer/decode.py:901-1281 (plan) and :1352-1660 (run) — note what is host metadata vs device tensors, and what gets cached on the wrapper.
  4. csrc/batch_decode.cu (200 lines, read in full) — Python tensors become paged_kv_t + raw pointers; plan info becomes workspace offsets.
  5. include/flashinfer/attention/scheduler.cuh:60-130 and :366-492 — the planner. This is the most “your world” file in the repo: occupancy estimation, binary-search load balancing, offset allocation.
  6. include/flashinfer/page.cuh:37-210 — the page table struct and address translation.
  7. include/flashinfer/attention/decode.cuh:380-608 — one full bandwidth-bound kernel. Skim for the pipeline structure, not the index arithmetic.
  8. include/flashinfer/attention/prefill.cuh — skim compute_qk (:878) just to see the mma fragments; then flashinfer/cascade.py:226-560 + include/flashinfer/attention/cascade.cuh.
  9. flashinfer/sampling.py + include/flashinfer/sampling.cuh:968-1090.
  10. flashinfer/jit/core.py + flashinfer/jit/attention/modules.py:915-960, then look inside ~/.cache/flashinfer/ after running anything.
  11. benchmarks/bench_batch_decode.py — 100 lines that tie the whole cost model together.

End-to-end trace to cement it, for one wrapper.run(q, kv_cache) decode call: decode.py:1352 → csrc/batch_decode.cu:83 → BatchDecodeWithPagedKVCacheDispatched (include/flashinfer/attention/decode.cuh:741) → BatchDecodeWithPagedKVCacheKernel (decode.cuh:613) → optional VariableLengthMergeStates (cascade.cuh:687).

6. Connections to your other study repos

  • sglang — primary consumer. python/sglang/srt/layers/attention/flashinfer_backend.py wraps these exact wrapper classes; its init_forward_metadata is where plan() gets called each scheduler step, and SGLang’s “capture batch sizes” map to the CUDA-graph padding you saw in 4.1. Its sampler calls flashinfer.sampling.*. SGLang’s radix cache decides what pages are shared; FlashInfer decides how cheaply they’re read.
  • vllm — optional attention backend (VLLM_ATTENTION_BACKEND=FLASHINFER); on Blackwell it is also the delivery vehicle for trtllm-gen attention cubins. Comparing vLLM’s FlashInfer backend wrapper with SGLang’s is a good exercise: same plan/run contract, different metadata plumbing.
  • nano-vllm — uses flash-attn instead of FlashInfer; reading its attention call next to SGLang’s FlashInfer backend shows exactly what the plan/run + page-table contract buys (and costs) versus a single monolithic kernel call.
  • dynamo / llm-d / gateway-api-inference-extension — never touch these kernels, but every latency model they embed (TTFT vs ITL tradeoffs, KV-cache-aware routing, saturation detection) is a function of the two regimes in 4.3. When a router predicts “adding one decode request costs ~X ms/step,” X is the marginal KV bytes divided by bandwidth — i.e., section 4.2.
  • xgrammar — its apply_token_bitmask_inplace CUDA kernel is a tiny cousin of the sampling kernels here: same “mutate logits on device, never sync to host” principle, applied just before FlashInfer’s sampling kernel draws the token.

7. Hands-on with your RTX 5080 (16 GB)

Platform honesty: WSL2, not native Windows. There are no Windows wheels, CI is Linux, the JIT pipeline (ninja + nvcc + .so + TVM-FFI load_module) targets Linux shared objects; the stray sys.platform == "win32" branch in flashinfer/autotuner.py:90 is not a support statement. Use WSL2 Ubuntu + CUDA toolkit 12.8+ (SM 12.0 needs it). Clone inside the WSL filesystem (~/), not under /mnt/c/ — JIT compilation hammers the filesystem.

Your GPU is supported. RTX 5080 = Blackwell SM 12.0, listed in README.md:73. The fa2 decode/prefill paths, cascade, and sampling all work; Hopper-only fa3 (sm90a) and B200-class MLA (sm100a) kernels won’t be available — flashinfer.utils.is_sm120a_supported is the gate tests use. Setup:

# inside WSL2 Ubuntu with CUDA 12.8+ toolkit
git clone https://github.com/flashinfer-ai/flashinfer.git --recursive && cd flashinfer
pip install --no-build-isolation -e . -v
export FLASHINFER_CUDA_ARCH_LIST="12.0f"   # skip detecting/compiling other archs
flashinfer show-config

Microbenchmark 1 — one attention call (and one JIT compile). Run the README quickstart (README.md:123-133): flashinfer.single_decode_with_kv_cache(q, k, v) with a 2048-token KV. First call stalls while nvcc builds the kernel — watch ~/.cache/flashinfer/ populate (generated/ sources, cached_ops/*.so). Second call is microseconds. That stall is the JIT story of 4.5 made visceral.

Microbenchmark 2 — the decode bandwidth sweep. benchmarks/bench_batch_decode.py already measures the exact thing you study:

# benchmarks/bench_batch_decode.py:70-78
measurements = bench_gpu_time(lambda: wrapper.run(q, kv_data))
ms = np.median(measurements)
io = q.numel() * q.element_size() + kv_data.numel() * kv_data.element_size()
print(f"execution time: {ms}ms")
print(f"memory bandwidth: {io / ms / 1024 / 1024:.2f} GB/s")

Trim the sweep loops at the bottom (:81-95) for 16 GB. The largest configs allocate batch x seq x 2048 bytes of KV (batch 512 x 16384 tokens ≈ 17 GB, OOM), so cap at batch ≤ 128 or seq ≤ 8192. What you should observe: at batch 1 the achieved GB/s is far below spec, because there are not enough CTAs; this is what split-KV planning fixes. As batch x seq grows, GB/s climbs and plateaus near the 5080’s ~960 GB/s GDDR7 ceiling, while latency becomes a clean linear function of total KV bytes. That plateau is the per-token decode cost model your schedulers assume. Bonus run: switch kv_dtype to torch.float8_e4m3fn and watch step time roughly halve, since the KV bytes halve and the small q tensor is unchanged. That is quantized KV as a bandwidth lever, with zero scheduler changes upstream.

Sanity math for the plateau: per token per layer, KV bytes = 2 (K and V) x num_kv_heads x head_dim x dtype_size = 2 x 4 x 128 x 2 B = 2 KB in this benchmark’s config; batch 64 x seq 4096 → ~537 MB per call → ~0.56 ms at 960 GB/s. Your measurement should land within ~20% of that.
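The same napkin math works as a script: change heads, dtype, or batch and re-predict before running the benchmark. The helper names are hypothetical. The 4 KV heads x 128 head_dim x fp16 config and the 960 GB/s figure are the ones quoted in this section. bench_gbps_to_decimal encodes the unit used by the print statement in the benchmark excerpt above.

```python
GB = 1e9


def kv_bytes_per_token_per_layer(num_kv_heads: int, head_dim: int, dtype_bytes: int) -> int:
    return 2 * num_kv_heads * head_dim * dtype_bytes       # factor 2: one K row and one V row


per_token = kv_bytes_per_token_per_layer(4, 128, 2)         # the benchmark config quoted above
call_bytes = 64 * 4096 * per_token                         # batch 64 x seq 4096
predicted_ms = call_bytes / (960 * GB) * 1e3               # RTX 5080, ~960 GB/s
largest = 512 * 16384 * per_token                          # 2**34 bytes: the card's whole 16 GiB
fp8_ratio = kv_bytes_per_token_per_layer(4, 128, 1) / per_token


def bench_gbps_to_decimal(printed: float) -> float:
    """bench_batch_decode.py divides bytes per ms by 1024*1024; convert to decimal GB/s."""
    return printed * 1024 * 1024 / 1e6


print(f"{call_bytes / 1e6:.0f} MB -> {predicted_ms:.2f} ms; largest config {largest / GB:.1f} GB")
```

One unit detail matters when you compare against the 960 GB/s spec. The benchmark divides by 1024 x 1024 per millisecond, so for the same transfer its printed "GB/s" reads about 5% lower than decimal GB/s.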

If you can’t run it locally: the next-best path is (1) the FlashInfer paper, arXiv 2501.01005; (2) the two blog posts — flashinfer.ai/2024/02/02/cascade-inference.html and flashinfer.ai/2025/03/10/sampling.html — which are the prose versions of 4.4 and 4.6; (3) docs/tutorials/kv_layout.rst and recursive_attention.rst in-repo; (4) read benchmarks/bench_batch_decode.py and predict its output from the napkin math above — the prediction exercise teaches the cost model even without the GPU.