The metrics that define the SLO: TTFT, ITL/TPOT, throughput, goodput, cost

Chapter 1 left us with a request that is not a single RPC but a long-lived state machine: it queues, it prefills, it decodes one token at a time, and it streams those tokens back over a connection that stays open for seconds. That shape has a consequence we now have to confront head-on. A stateless service has essentially one latency number that matters, request latency, and you reason about its distribution. An autoregressive request has at least two, and they are not the same kind of thing. One measures how long you wait before anything happens. The other measures how smoothly things keep happening once they start. They are produced by different phases of the computation, they are bounded by different hardware resources (Chapter 3 will name those resources precisely), and worst of all, they trade against each other and against throughput. You cannot maximize all three. So before we can optimize anything, we have to agree on what “fast” even means, and the honest answer is that the word is meaningless until you have named a metric, a percentile, and a load.

This chapter is about getting that vocabulary exactly right, because every later chapter is, at bottom, a story about moving one of these numbers at the expense of another. The good news for a traffic-infrastructure engineer is that none of the queueing intuition you already have is wasted. The bad news is that you have to apply it to two coupled queues at once, and the denominator you divide cost by is not requests, it is GPU-seconds.

Two latency SLOs, not one

Start with the two latencies, because they are the ones people conflate. The first is time to first token, TTFT: the interval from when the request arrives at the engine to when the user sees the first generated token. It is dominated by two things, the time the request spends waiting in the queue before it is scheduled (call this the queue delay), and the cost of the prefill forward pass over the whole prompt (the one big computation that reads every prompt token and produces the first output token). The second is inter-token latency, ITL, sometimes reported as its per-request average time per output token, TPOT: the gap between consecutive streamed tokens during the decode phase (the long tail of one-token-at-a-time steps that follows prefill). The two latencies live in two different stretches of a request’s life, and the diagram below traces that life from arrival to the last streamed token, marking exactly which interval each metric measures.

sequenceDiagram
    participant U as User / client
    participant Q as Engine queue
    participant E as GPU engine
    U->>Q: request arrives (start time st)
    Note over Q: queue delay (waiting to be scheduled)
    Q->>E: scheduled
    Note over E: prefill, one pass over the whole prompt
    E-->>U: token 1
    Note over U,E: TTFT = queue delay + prefill (measured once)
    Note over E: decode step
    E-->>U: token 2
    Note over U,E: ITL sample = gap before this token
    Note over E: decode step
    E-->>U: token 3
    Note over U,E: ITL sample
    Note over E: decode continues
    E-->>U: token N (last)
    Note over U,E: TPOT = (latency - TTFT) / (N - 1)

Read the diagram top to bottom as the timeline of a single request. Everything above the first arrow out of the engine is folded into TTFT; everything after it is the territory of ITL and TPOT. That single visual split is the whole reason one latency number is not enough.

vLLM measures both, and the cleanest way to see the definitions is in the serving benchmark client, which times them from the outside exactly as a user’s browser would. In vllm/benchmarks/lib/endpoint_request_func.py, each streamed chunk is stamped with timestamp = time.perf_counter(), and the first chunk is then treated specially:

# First token
if ttft == 0.0:
    ttft = timestamp - st
    output.ttft = ttft

# Decoding phase
else:
    output.itl.append(timestamp - most_recent_timestamp)

most_recent_timestamp = timestamp

That is the whole distinction. The code keeps a running most_recent_timestamp; the first chunk to arrive (when ttft is still its sentinel 0.0) is timed against the request’s start time st and recorded as the one TTFT, and every chunk after that records the gap since the previous chunk as one ITL sample. TTFT is therefore measured exactly once per request; ITL is measured once per token-after-the-first. A request that emits $N$ output tokens produces exactly one TTFT and $N - 1$ ITLs, so a 500-token request yields one TTFT and 499 ITLs. This asymmetry in sample counts matters later when we talk about percentiles: a single long request floods the ITL pool with hundreds of samples but contributes only one TTFT sample, so the two distributions are not weighted the same way even when they come from the same traffic.
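The same bookkeeping can be reproduced outside the benchmark. The sketch below is not vLLM's code: `StreamTiming` and `time_stream` are hypothetical names, the timestamps are made up, and it assumes each streamed chunk carries exactly one token.

```python
from dataclasses import dataclass, field


@dataclass
class StreamTiming:
    ttft: float = 0.0                               # measured once per request
    itl: list[float] = field(default_factory=list)  # one sample per later chunk


def time_stream(start: float, chunk_timestamps: list[float]) -> StreamTiming:
    """Split chunk arrival times into one TTFT and N - 1 ITL samples."""
    timing = StreamTiming()
    most_recent = start
    for i, ts in enumerate(chunk_timestamps):
        if i == 0:
            timing.ttft = ts - start             # first chunk: timed against request start
        else:
            timing.itl.append(ts - most_recent)  # later chunks: gap since previous chunk
        most_recent = ts
    return timing


# Hypothetical timestamps in seconds: request sent at t=0, four chunks arrive.
timing = time_stream(0.0, [0.5, 0.75, 1.0, 1.5])
print(timing.ttft, timing.itl)
```

Four chunks give one TTFT and three ITL samples, which is the $N - 1$ relationship described above.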

TPOT is then just the decode duration amortized over the output tokens, computed per request in vllm/benchmarks/serve.py:

tpot = 0
if output_len > 1:
    latency_minus_ttft = outputs[i].latency - outputs[i].ttft
    tpot = latency_minus_ttft / (output_len - 1)
    tpots.append(tpot)

Note the output_len - 1: the first token’s cost lives in TTFT, so TPOT divides the remaining latency by the remaining tokens. Writing the end-to-end request latency as $L$ and the output length as $N$, that is

$$\text{TPOT} = \frac{L - \text{TTFT}}{N - 1}.$$

ITL is the raw per-step distribution; TPOT is its per-request mean. They answer different questions. ITL tells you whether the stream stutters; TPOT tells you the steady-state rate a given request saw. A request can have a fine TPOT and still stutter badly if a few ITLs spiked, which is exactly what happens when a long prefill from another request lands in the same step and stalls everyone’s decode. Hold that thought, it is the entire motivation for chunked prefill in Chapter 6.
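A small worked example makes the gap between the two metrics concrete. The numbers below are invented for illustration, and `tpot_ms` is a hypothetical helper that mirrors the formula rather than vLLM's own function.

```python
def tpot_ms(latency_ms: float, ttft_ms: float, output_len: int) -> float:
    """Per-request TPOT: decode time spread over the tokens after the first."""
    if output_len <= 1:
        return 0.0  # a single-token response has no decode gaps
    return (latency_ms - ttft_ms) / (output_len - 1)


# Hypothetical request: 11 output tokens, so 10 inter-token gaps (milliseconds).
ttft = 500
itls = [25] * 9 + [275]      # nine smooth steps and one long stall
latency = ttft + sum(itls)   # end-to-end latency as the client sees it

request_tpot = tpot_ms(latency, ttft, output_len=len(itls) + 1)
worst_itl = max(itls)
print(f"TPOT = {request_tpot} ms, worst ITL = {worst_itl} ms")
```

In this made-up request the TPOT comes out at 50 ms, which would pass a 50 ms TPOT budget, while one gap was 275 ms. The mean absorbs the stall; the ITL distribution does not.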

The engine reports the same quantities internally, not from the client’s vantage but from the engine core’s. In vllm/v1/metrics/stats.py, the per-request state object carries the timestamps, and the first-token interval is computed against the request’s arrival time:

first_token_latency = self._time_since(req_stats.arrival_time)
self.time_to_first_tokens_iter.append(first_token_latency)
req_stats.first_token_latency = first_token_latency

and every later token contributes an ITL measured between engine-core timestamps:

itl = engine_core_timestamp - req_stats.last_token_ts
self.inter_token_latencies_iter.append(itl)

Source: vllm/v1/metrics/stats.py

These feed the Prometheus histograms in vllm/v1/metrics/loggers.py: vllm:time_to_first_token_seconds and vllm:inter_token_latency_seconds, plus a per-request vllm:request_time_per_output_token_seconds. It is worth a glance at the bucket boundaries, because the people who wrote them encoded their expectations about what these distributions look like. TTFT buckets run from a millisecond out to thousands of seconds:

histogram_time_to_first_token = self._histogram_cls(
    name="vllm:time_to_first_token_seconds",
    documentation="Histogram of time to first token in seconds.",
    buckets=[
        0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
        0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0, 160.0,
        640.0, 2560.0,
    ],

The fact that TTFT needs buckets out to 2560 seconds and ITL only to 80 tells you something true: TTFT has a far longer tail, because it absorbs queueing delay. Under load, the prefill compute is roughly constant for a given prompt length, but the wait in the queue is unbounded. When people complain that “the model got slow,” they almost always mean TTFT, and they almost always mean its tail.

Why one number is a lie: percentile and load

Here is where your distributed-systems instincts pay off, and where most LLM benchmark numbers fall apart. A mean TTFT is nearly useless, for the same reason a mean request latency is useless in any tail-sensitive system: the mean hides the queue. The engine’s histograms exist precisely so you can ask for p50, p95, p99, and the benchmark client reports percentiles by default rather than just means. In serve.py the reported metric set is ttft,tpot,itl and the percentiles are computed straight from the sample arrays:

percentiles_ttft_ms=[
    (p, np.percentile(ttfts or 0, p) * 1000) for p in selected_percentiles
],

But a percentile is still only half a statement. The other half is the load it was measured at. TTFT p99 at one request per second and TTFT p99 at the engine’s saturation point are different universes, because the queueing term dominates as you approach capacity, and it does so superlinearly. The intuition is the standard queueing one: call the arrival rate $\lambda$ and the rate the engine can actually drain $\mu$, and let the utilization be $\rho = \lambda / \mu$. As $\rho \to 1$, the average number of requests waiting ahead of you does not grow linearly, it blows up toward infinity, and TTFT (which contains the queue delay) blows up with it. Far below capacity, adding load barely moves p99; near capacity, a tiny increase in load sends p99 vertical. That is the classic latency-versus-throughput knee, and an LLM engine has one just like any queue. The curve below plots the standard M/M/1 mean queue delay, $W_q \propto \rho/(1-\rho)$, against utilization: the factor is 1 at $\rho = 0.5$ and still only 4 at $\rho = 0.8$, but it reaches 9 at $\rho = 0.9$ and 99 at $\rho = 0.99$, which is exactly why a TTFT number quoted without its load is meaningless.

Illustrative: the M/M/1 relation $W_q \propto \rho/(1-\rho)$ gives the correct shape of the knee, not a measured TTFT for any specific engine.

A benchmark that reports “p99 TTFT = 200ms” without naming the request rate, the prompt-length distribution, and the output-length distribution has told you almost nothing, because it has not told you where on that curve, relative to the knee, the number was taken.
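To see where the knee sits numerically, the M/M/1 factor can be tabulated directly. This is only the textbook relation, expressed in units of mean service time; `mm1_queue_delay_factor` is a hypothetical helper and none of the values are measurements of any engine.

```python
def mm1_queue_delay_factor(rho: float) -> float:
    """Mean M/M/1 queue delay in units of mean service time: rho / (1 - rho)."""
    if not 0.0 <= rho < 1.0:
        raise ValueError("utilization must be in [0, 1)")
    return rho / (1.0 - rho)


# Tabulate the factor at a few utilizations to expose the knee.
for rho in (0.5, 0.8, 0.9, 0.95, 0.99):
    factor = mm1_queue_delay_factor(rho)
    print(f"rho={rho:.2f}  W_q / service_time = {factor:6.1f}")
```

Going from 0.5 to 0.8 utilization multiplies the wait by about 4; going from 0.9 to 0.99 multiplies it by about 11, for a far smaller change in load.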

The arrival process matters too, and vLLM’s benchmark is careful about it in a way worth copying. Requests are not fired in a tight loop; they are spaced by a configurable stochastic process. The get_request generator in serve.py draws inter-arrival gaps from a gamma distribution whose shape is the --burstiness knob:

theta = 1.0 / (current_request_rate * burstiness)

# Sample the request interval from the gamma distribution.
# If burstiness is 1, it follows exponential distribution.
delay_ts.append(np.random.gamma(shape=burstiness, scale=theta))

At burstiness = 1 the gamma collapses to an exponential, which is to say arrivals are a Poisson process, the standard well-behaved baseline. Push burstiness below 1 and the gaps become more variable, clustering requests into bursts that hammer the queue; push it above 1 and arrivals become more uniform than Poisson, which is gentler than reality. This single parameter lets you ask the question that actually predicts production behavior: not “what is p99 at 10 rps” but “what is p99 at 10 rps when the traffic is bursty.” The prompt and output length distributions come from the dataset layer in vllm/benchmarks/datasets/datasets.py, which can replay real traces (ShareGPT and friends) or synthesize controlled random lengths, because a benchmark on uniform 128-token prompts will lie to you about a workload of mixed 50-token and 8000-token prompts.
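The effect of the shape parameter is easy to check empirically. The sketch below uses the standard library's `random.gammavariate` as a stand-in for `np.random.gamma`; `sample_gaps` and `coefficient_of_variation` are hypothetical helpers, and the rate of 10 requests per second is arbitrary.

```python
import random
from statistics import mean, pstdev


def sample_gaps(request_rate: float, burstiness: float, n: int, seed: int = 0) -> list[float]:
    """Draw n inter-arrival gaps (seconds) from a gamma distribution."""
    rng = random.Random(seed)
    theta = 1.0 / (request_rate * burstiness)  # scale chosen so the mean gap is 1 / rate
    return [rng.gammavariate(burstiness, theta) for _ in range(n)]


def coefficient_of_variation(gaps: list[float]) -> float:
    """Standard deviation over mean; 1.0 for exponential (Poisson) gaps."""
    return pstdev(gaps) / mean(gaps)


for burstiness in (0.5, 1.0, 4.0):
    gaps = sample_gaps(request_rate=10.0, burstiness=burstiness, n=20000)
    print(f"burstiness={burstiness}: mean gap={mean(gaps):.4f}s, "
          f"CV={coefficient_of_variation(gaps):.2f}")
```

The mean gap stays near 0.1 s in every case, so the average rate is unchanged. What moves is the variability: a gamma distribution's coefficient of variation is $1/\sqrt{\text{shape}}$, so lower burstiness values spread the gaps out and produce clusters of arrivals.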

Throughput, and why it fights latency

Throughput is the third axis, and on the surface it is the friendly one: tokens per second, the thing you are paying for. vLLM tracks it as a rate over an interval in vllm/v1/metrics/loggers.py, splitting prompt (prefill) tokens from generation (decode) tokens because they cost differently:

prompt_throughput = self._get_throughput(self.num_prompt_tokens, now)
generation_throughput = self._get_throughput(self.num_generation_tokens, now)

The reason throughput deserves its own axis, rather than being a derived consequence of latency, is the mechanism Chapter 5 will build in full: continuous batching. The engine raises throughput by packing more sequences into each forward pass, because a decode step is memory-bandwidth-bound (Chapter 3), so reading the weights once to serve a batch of 200 sequences is dramatically more efficient per token than reading them for one. Larger batches mean higher token throughput and better GPU utilization.

And there is the conflict, stated plainly. The same batching that lifts throughput lengthens latency. A bigger batch takes longer per step, so every sequence’s ITL grows. Admitting a large prefill into the batch to keep the GPU busy stalls the decodes already running, spiking their ITL, and it holds up the requests queued behind it, raising their TTFT. The knob that governs this directly is the per-step token budget, max_num_batched_tokens, which Chapter 5 introduces as the scheduler’s rate limiter; turning it up buys throughput and costs tail latency, turning it down does the reverse. There is no setting that wins both. Because both effects are continuous functions of batch size, the tradeoff is really two curves, not a fork: the chart below sweeps batch size on the x-axis and plots aggregate token throughput against per-sequence ITL on the two y-axes.

Illustrative: a simple model where step time grows linearly with batch size (weights read once, activation work scales with the batch). The shapes (saturating throughput, near-linearly rising ITL) are the real behavior; the absolute numbers are not measured.

The two curves make the bind concrete: throughput climbs steeply at small batches but flattens as the weight read is fully amortized, while ITL just keeps rising, so every step toward cheaper, higher-throughput operation also pushes tail latency up. No single batch size wins both axes. This is not an implementation wart to be optimized away. It is the central tension of the whole problem, and it was formalized most cleanly by the Sarathi-Serve work (arXiv:2403.02310), which named the throughput-latency tradeoff explicitly and showed that the prefill-decode interference is the thing you have to schedule around. We will quote its mechanism, chunked prefill, in Chapter 6; here it earns its place as the paper that says out loud that you cannot have it all.
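The linear step-time model behind those curves fits in a few lines. The constants here (20 ms to read the weights, 0.2 ms of extra work per sequence) are assumptions chosen only to show the shape, and the function names are hypothetical.

```python
def step_time_ms(batch_size: int, weight_read_ms: float = 20.0, per_seq_ms: float = 0.2) -> float:
    """One decode step: a fixed weight read plus work that scales with the batch."""
    return weight_read_ms + per_seq_ms * batch_size


def throughput_tok_per_s(batch_size: int) -> float:
    """Each step emits one token per sequence in the batch."""
    return batch_size / step_time_ms(batch_size) * 1000.0


# Sweep batch size: ITL is the step time, throughput is tokens emitted per second.
for batch_size in (1, 10, 100, 1000):
    print(f"batch={batch_size:5d}  ITL={step_time_ms(batch_size):6.1f} ms  "
          f"throughput={throughput_tok_per_s(batch_size):7.1f} tok/s")
```

Under these assumptions throughput can never exceed 1000 / 0.2 = 5000 tokens per second, however large the batch, while ITL grows without bound. That is the saturating curve and the rising line in one model.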

Goodput: latency and throughput in one number

If throughput is “tokens per second” and latency is “but were they fast enough,” the obvious move is to combine them, and that combination has a name: goodput. Goodput is the rate of requests that complete while meeting all of their SLOs simultaneously. A request that finishes but blew its TTFT budget counts toward throughput but not toward goodput. This is the metric that actually correlates with a happy user and a defensible cost, and it is why the benchmark supports it directly.

The definition in vllm/benchmarks/serve.py is admirably blunt. You pass SLO thresholds as --goodput ttft:200 tpot:50 and the client checks each completed request against every threshold:

for req_metric in zip(*valid_metrics):
    is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)])
    if is_good_req:
        good_completed += 1

all(...) is the load-bearing word: a request is “good” only if it satisfies the TTFT bound and the TPOT bound and the end-to-end bound, whichever you specified. The diagram below traces one completed request through that gate to show why “good” is a conjunction, not an average: failing any single bound knocks the whole request out of the count, no matter how comfortably it cleared the others.

flowchart TD
    R["completed request"] --> A{"TTFT under budget?"}
    A -->|"no"| BAD["not counted (throughput only)"]
    A -->|"yes"| B{"TPOT under budget?"}
    B -->|"no"| BAD
    B -->|"yes"| C{"end-to-end under budget?"}
    C -->|"no"| BAD
    C -->|"yes"| GOOD["good_completed += 1"]
    GOOD --> RATE["goodput = good_completed / wall_clock_seconds"]

The goodput is then those good completions over wall-clock time, where $C_{\text{good}}$ is the count of SLO-satisfying completions and $T$ is the measurement window:

$$\text{goodput} = \frac{C_{\text{good}}}{T}.$$

request_goodput=good_completed / dur_s,

Source: vllm/benchmarks/serve.py
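A stripped-down version of the same check, on invented data, shows how goodput and raw request throughput diverge. `request_goodput` below is a hypothetical standalone function, not the benchmark's; it treats a metric exactly at its bound as passing, matching the `s >= r` comparison above.

```python
def request_goodput(per_request_ms: list[dict[str, float]],
                    slo_ms: dict[str, float],
                    duration_s: float) -> float:
    """Requests per second that satisfy every SLO bound at once."""
    good_completed = 0
    for metrics in per_request_ms:
        # A request is good only if all of its metrics are within budget.
        if all(metrics[name] <= bound for name, bound in slo_ms.items()):
            good_completed += 1
    return good_completed / duration_s


# Four hypothetical completed requests over a 2-second window (milliseconds).
requests = [
    {"ttft": 150, "tpot": 40},  # good
    {"ttft": 250, "tpot": 30},  # fails TTFT
    {"ttft": 100, "tpot": 60},  # fails TPOT
    {"ttft": 200, "tpot": 50},  # exactly on both bounds: still good
]
slo = {"ttft": 200, "tpot": 50}

goodput = request_goodput(requests, slo, duration_s=2.0)
throughput = len(requests) / 2.0
print(f"throughput = {throughput} req/s, goodput = {goodput} req/s")
```

All four requests completed, so request throughput is 2.0 per second, but only two cleared both bounds, so goodput is 1.0 per second.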

Goodput is the right optimization target precisely because it refuses to let you cheat. You cannot inflate it by cranking the batch size, because that stretches every decode step, bloats ITL and TPOT, and the over-budget requests stop counting. You cannot inflate it by shrinking the batch to protect latency, because then you complete too few requests per second. It forces the engine toward the actual Pareto frontier of the two-latency, one-throughput tradeoff. The framing comes from the disaggregation literature, DistServe (arXiv:2401.09670), which optimized prefill and decode separately for goodput rather than raw throughput, an idea Chapter 17 takes apart; the vLLM CLI help even points you at that paper for the definition. For now the lesson is narrow and important: when you set an SLO, you are not setting one number, you are setting a conjunction, and goodput is how you measure whether the conjunction holds at a given load.

The cost denominator is GPU-seconds

The last metric is the one that pays the bills, and it is the one most often computed wrong. The instinct from stateless serving is cost-per-request, or cost-per-million-tokens as a flat rate. Both are misleading for the same reason: a request does not consume a request’s worth of a GPU, it consumes whatever slice of GPU time its tokens occupied while it was resident, and that slice depends on how well the engine packed it alongside others.

The correct denominator is GPU-seconds: the amount of GPU wall-clock time consumed, summed across the GPUs you rented. A GPU costs a fixed amount per second whether it is saturated or idle, so the only cost lever you have is how many useful tokens you extract per GPU-second, which is to say throughput per GPU, which is to say utilization. An idle GPU still bills, so every second it spends not producing tokens is pure waste; the cheapest possible operation keeps it producing useful tokens every second it is alive. This reframes every optimization in the book as a cost optimization. Prefix caching (Chapter 7) is cheaper because it skips recomputing shared tokens, spending fewer GPU-seconds per request. Quantization (Chapter 12) is cheaper because it reads fewer bytes per token off memory, fitting more sequences per GPU-second. Speculative decoding (Chapter 13) is cheaper because it verifies several tokens per forward pass. None of these changes the per-request price directly; all of them change the GPU-seconds in the denominator.
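The arithmetic is short enough to write down. The sketch below converts a rental price and a sustained token rate into cost per million tokens; the function name, the $3.60 hourly price, and the throughput figures are all made up for illustration.

```python
def cost_per_million_tokens(gpu_hourly_usd: float, num_gpus: int,
                            tokens_per_second: float) -> float:
    """Dollars per million tokens, given fleet price and sustained throughput."""
    usd_per_gpu_second = gpu_hourly_usd / 3600.0
    gpu_seconds_per_token = num_gpus / tokens_per_second  # the quantity optimizations shrink
    return usd_per_gpu_second * gpu_seconds_per_token * 1_000_000


# Same hypothetical GPU at the same price; only throughput per GPU changes.
low_util = cost_per_million_tokens(3.60, num_gpus=1, tokens_per_second=1000.0)
high_util = cost_per_million_tokens(3.60, num_gpus=1, tokens_per_second=4000.0)
print(f"1000 tok/s: ${low_util:.2f} per 1M tokens")
print(f"4000 tok/s: ${high_util:.2f} per 1M tokens")
```

The hourly price never moved. Quadrupling the tokens extracted per GPU-second cut the cost per token to a quarter, which is the sense in which every throughput optimization is a cost optimization.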

This is also why the two latency SLOs reappear inside the cost story. The cheapest way to run is at maximum batch size and maximum utilization, which is exactly the latency-hostile end of the tradeoff. Your SLOs cap how far toward cheap you are allowed to go. The signals you watch to know where you actually are on that curve are the scheduler gauges in vllm/v1/metrics/loggers.py, the queue depth and the cache pressure:

self.gauge_scheduler_running[engine_idx].set(scheduler_stats.num_running_reqs)
...
self.gauge_kv_cache_usage[engine_idx].set(scheduler_stats.kv_cache_usage)

vllm:num_requests_running is the number of sequences currently in the batch (your batch occupancy), vllm:num_requests_waiting is the queue depth (your backlog), and vllm:kv_cache_usage_perc is how full the KV cache is, which is how close the engine is to running out of the memory that bounds concurrency (Chapter 4 explains why that memory, not compute, is usually the wall). Read together they locate you on the knee from the previous section. When waiting is near zero and KV usage has slack, you are on the flat part of the curve with headroom to spare. When waiting climbs and KV usage pins near 1.0, you are at or past the knee: the engine cannot admit new work without evicting old work, so TTFT’s tail is about to explode and you are either out of headroom or about to start preempting (Chapter 6’s load-shedding mechanism). These same gauges become the inputs to autoscaling in Chapter 20 and to cache-aware routing in Chapter 18, which is the sense in which this chapter’s metrics are the language the rest of the fleet speaks.
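That reading of the gauges can be turned into a rough classifier. Everything in the sketch below is an assumption: the function name, the labels, and especially the thresholds (any backlog at all, KV usage at or above 0.9) are placeholders to tune against your own traffic, not values vLLM defines.

```python
def locate_on_knee(num_waiting: int, kv_cache_usage: float,
                   waiting_threshold: int = 0, kv_high: float = 0.9) -> str:
    """Classify operating point from queue depth and KV cache fill (0.0 to 1.0)."""
    backlog = num_waiting > waiting_threshold
    kv_full = kv_cache_usage >= kv_high
    if not backlog and not kv_full:
        return "headroom"          # flat part of the curve
    if backlog and kv_full:
        return "at-or-past-knee"   # TTFT tail about to grow, preemption likely
    return "mixed-signals"         # only one gauge is elevated; keep watching


# Hypothetical gauge readings scraped at three different moments.
print(locate_on_knee(num_waiting=0, kv_cache_usage=0.45))
print(locate_on_knee(num_waiting=12, kv_cache_usage=0.97))
print(locate_on_knee(num_waiting=3, kv_cache_usage=0.60))
```

The middle category is there because the two gauges do not always move together. A backlog with KV slack, for example, suggests something other than cache memory is limiting admission.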

Handoff

So “fast” decomposes into a small, precise vocabulary: TTFT and ITL/TPOT as two latency SLOs produced by the two phases, throughput as the rate that fights them, goodput as the conjunction that holds them honest, and GPU-seconds as the denominator that turns all of it into money. We now have the metrics. What we do not yet have is the explanation for why the two phases behave so differently, why prefill is the one that saturates compute and decode is the one that starves on memory bandwidth. That asymmetry is the root cause of nearly every technique that follows, and it is the subject of Chapter 3.

Further reading

  • Taming Throughput-Latency Tradeoff in LLM Inference with Sarathi-Serve — arXiv:2403.02310 — formalizes the prefill/decode throughput-latency tension and the goodput framing; read it for why no single batch size wins both metrics.
  • DistServe: Disaggregating Prefill and Decoding for Goodput-optimized LLM Serving — arXiv:2401.09670 — the source of the goodput definition vLLM’s --goodput flag uses; read it for how separating the two phases changes what you can optimize (revisited in Chapter 17).