Speculative decode is an inference optimization where you use a small draft model to generate a short sequence of output tokens, then run those draft tokens through a full-sized model in a single multi-token forward pass to check whether the full model agrees with them.

This can provide a speedup if the draft model is small enough to perform multiple decode steps in the time it would take the big model to perform a single decode, and if the full model accepts enough of the draft tokens.
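Here’s a back-of-the-envelope estimate of how draft length, acceptance rate, and draft cost trade off. It’s a sketch under simplifying assumptions that aren’t from the text: each draft token is accepted independently with the same probability `alpha`, verifying `k` draft tokens costs about as much as one full-model decode step, and `cost_ratio` (draft step time divided by full step time) is a hypothetical number you’d measure yourself.

```python
def expected_tokens_per_round(alpha: float, k: int) -> float:
    """Expected tokens emitted per verify pass, assuming each draft token is
    accepted independently with probability alpha (a simplification).
    Counts the accepted draft tokens plus the one token the full model always
    contributes (a correction on rejection, or a bonus token if all k pass)."""
    if alpha >= 1.0:
        return k + 1.0
    # 1 + alpha + alpha^2 + ... + alpha^k
    return (1.0 - alpha ** (k + 1)) / (1.0 - alpha)


def estimated_speedup(alpha: float, k: int, cost_ratio: float) -> float:
    """Tokens per unit time relative to plain decoding with the full model.
    cost_ratio: one draft decode step's time / one full forward pass's time."""
    round_cost = k * cost_ratio + 1.0  # k draft steps + 1 verify pass, in full-step units
    return expected_tokens_per_round(alpha, k) / round_cost


print(round(estimated_speedup(0.8, 3, 0.1), 2))  # 2.27
```

With `alpha = 0`, each round still yields one token (the full model’s correction) but pays for `k` draft steps, so the estimate drops below 1: a poorly matched draft model makes decoding slower.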

Simple example

This is my understanding of how it works.

Let’s say you’re using Llama-3.1-70B as the big model and Llama-3.1-8B as the draft model. You have an input query that’s 100 tokens long.

  1. The 100-token input is sent to both the 70B and 8B models. Each model then
    1. computes keys and values for all 100 tokens in a single forward pass
    2. stores those keys and values in the model’s KV cache
    3. generates a logit matrix of size [100, vocab_size]. The logit vector at position 100 describes the probability distribution for token 101 to be generated.
  2. The 8B draft model decodes, say, 3 draft tokens. Each decode step proceeds as normal:
    1. a single input token, embedded as a vector of size [1, d_model], goes into the model
    2. computes attention over the cached keys and values
    3. produces a logit vector of size [1, vocab_size]
    4. picks one token from the logit vector to be the 101st token generated
    5. saves the probability that this 101st token would’ve been picked
    6. computes keys and values of that new token and adds them to the KV cache
    7. repeats two more times, resulting in candidate tokens 101, 102, and 103
  3. The full 70B model then verifies those draft tokens:
    1. candidate tokens 101, 102, and 103 are fed into the full model together as one 3-token sequence (embeddings of shape [3, d_model]) in a single forward pass
    2. the full model computes keys and values for these three tokens and appends them to the KV cache
    3. the full model outputs a logit matrix of shape [3, vocab_size], whose rows are the logit vectors for positions 101, 102, and 103 (that is, the predicted distributions for tokens 102, 103, and 104)
  4. The output tokens are then accepted or rejected based on the logit matrix from the previous step:
    1. The logit vector for position 100 generated by the full 70B model’s initial prefill (step 1) is compared to the first draft token (step 2). Mathematically, this comparison is something like…
      1. Apply softmax to the full 70B model’s logit vector that popped out after position 100 to get a probability distribution
      2. Look up the probability that this distribution assigns to the first draft token
      3. Recall the probability that the 8B draft model would’ve picked this draft token, saved in Step 2.5 above. We now have the probability that the 70B model would’ve picked the draft token vs. the probability that the 8B model would’ve picked it
      4. Divide the 70B probability from #2 by the 8B probability from #3 to get the acceptance probability. Cap it at 100%; that is, if the 70B model is more likely than the 8B model to have picked the token, the token is always accepted.
      5. Roll the dice: draw a random number uniformly between 0 and 1. If it’s less than the acceptance probability, keep draft token 1 at position 101; otherwise, reject it.
    2. If we accepted the token at 101, move on to draft token 102: compare the probability that the 70B model would’ve picked it (from the first row of the logit matrix computed in Step 3) to the probability that the 8B model would’ve picked it. Compute the acceptance probability, then accept or reject this second draft token at position 102.
    3. If we accepted the token at 102, repeat the process for 103.
    4. If we accepted the token at 103, we already have the logits for token 104 from the full 70B model; we calculated it in Step 3! So we can directly pick token 104 as well.
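The accept/reject loop above can be sketched in plain Python. This is a minimal sketch that uses toy dictionaries mapping token IDs to probabilities in place of softmaxed logit vectors; `verify_round`, `prefill_last`, and `verify_rows` are illustrative names, not from any library. It also encodes the row alignment from Step 3: the prefill’s last row scores the first draft token, and each verify row scores the draft token that follows the one it was computed from.

```python
import random
from typing import Dict, List, Optional, Tuple

Dist = Dict[int, float]  # token id -> probability (toy stand-in for a softmaxed logit vector)


def acceptance_probability(p_full: float, p_draft: float) -> float:
    """min(1, p_full / p_draft); p_draft > 0 because the draft model picked the token."""
    return min(1.0, p_full / p_draft)


def verify_round(
    draft_tokens: List[int],
    draft_dists: List[Dist],
    prefill_last: Dist,
    verify_rows: List[Dist],
    rng: random.Random,
) -> Tuple[int, Optional[Dist]]:
    """Returns (number of accepted draft tokens, bonus distribution or None).

    prefill_last: full model's distribution for the first draft token (token 101).
    verify_rows[i]: full model's distribution after draft token i, so it scores
    draft token i + 1; the last row is the bonus distribution (token 104)."""
    scoring = [prefill_last] + verify_rows[:-1]
    for i, (tok, q, p) in enumerate(zip(draft_tokens, draft_dists, scoring)):
        a = acceptance_probability(p.get(tok, 0.0), q[tok])
        if rng.random() >= a:  # u ~ Uniform[0, 1); accept only when u < a
            return i, None     # token i rejected; every draft after it is discarded
    return len(draft_tokens), verify_rows[-1]
```

If the full model gives a draft token at least as much probability as the draft model did, `a` is 1 and the token is always kept; if it gives the token zero probability, the token is always rejected.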

So at its core, the computations for this process are:

  1. Decoding tokens one at a time using a small draft model (in our example, Llama-3.1-8B) to generate draft tokens and the probabilities that the draft model would pick them. Even though this is three forward passes, the model is small, so they happen quickly.
  2. Running the draft tokens through the full model in a single prefill-style forward pass to get the full model’s probability distribution at every draft position. Even though we’re processing three tokens, there is only one forward pass, so this takes roughly as long as a single ordinary decode step rather than three. We’re exploiting the fact that during decode, most of the time spent computing these logits is actually spent waiting for model weights and cached keys/values to be copied from HBM into the GPU’s on-chip SRAM. The time required to actually compute the logits, whether for one token or three, is tiny compared to the time spent waiting on memory.
  3. Accepting/rejecting each draft token, one at a time.
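A rough roofline-style calculation shows why a 3-token verify pass costs about the same as a 1-token decode step. This sketch uses hypothetical hardware numbers (3 TB/s of aggregate memory bandwidth, 1 PFLOP/s of compute, bf16 weights at 2 bytes per parameter), not measurements, and it ignores KV-cache reads and attention FLOPs, which are small for a 100-token context.

```python
from typing import Tuple


def forward_time_bounds(n_params: float, n_tokens: int, bytes_per_param: float,
                        mem_bw: float, flops: float) -> Tuple[float, float]:
    """Lower bounds in seconds for one forward pass over n_tokens new tokens.

    memory_s: every weight is read from HBM once per pass, however many tokens.
    compute_s: roughly 2 FLOPs (one multiply-add) per parameter per token."""
    memory_s = n_params * bytes_per_param / mem_bw
    compute_s = 2.0 * n_params * n_tokens / flops
    return memory_s, compute_s


BW, FLOPS = 3e12, 1e15  # hypothetical bandwidth (bytes/s) and compute (FLOP/s)
for label, params, tokens in [("70B, 1 token", 70e9, 1),
                              ("70B, 3 tokens", 70e9, 3),
                              ("8B, 1 token", 8e9, 1)]:
    mem_s, comp_s = forward_time_bounds(params, tokens, 2, BW, FLOPS)
    print(f"{label}: memory {mem_s * 1e3:.1f} ms, compute {comp_s * 1e3:.2f} ms")
```

Under these assumed numbers, every pass is bound by the memory term (about 47 ms for 70B, about 5 ms for 8B), while the compute term stays under a millisecond even for 3 tokens. So three 8B draft steps plus one 3-token 70B verify pass come to roughly 16 + 47 ms, versus roughly 140 ms for three ordinary 70B decode steps.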

If a draft token is rejected, so are all the draft tokens after it. For example, if draft token 101 was accepted but draft token 102 was rejected,

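  1. We still have the full probability distribution from the 70B model for token 102; it’s what we used to reject that draft token. So we can pick a replacement token 102 without running the 70B model again. (With the probabilistic acceptance rule, the replacement is sampled from the 70B distribution minus the 8B distribution, with negative entries clipped to zero and the result renormalized; with greedy decoding, it’s just the 70B model’s top token.)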
  2. We then begin the speculative decode process starting at position 103.
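A minimal sketch of choosing the replacement token under the probabilistic acceptance rule, assuming toy dictionaries of token probabilities in place of softmaxed logits: subtract the 8B distribution from the 70B distribution, clip negatives to zero, renormalize, and sample. `residual_distribution` and `sample_replacement` are illustrative names.

```python
import random
from typing import Dict

Dist = Dict[int, float]  # token id -> probability


def residual_distribution(p_full: Dist, q_draft: Dist) -> Dist:
    """Normalized max(0, p_full - q_draft) over the full model's tokens."""
    diff = {t: max(0.0, p - q_draft.get(t, 0.0)) for t, p in p_full.items()}
    total = sum(diff.values())
    if total == 0.0:
        # Only happens when p_full == q_draft, in which case no token is ever rejected
        return dict(p_full)
    return {t: d / total for t, d in diff.items() if d > 0.0}


def sample_replacement(p_full: Dist, q_draft: Dist, rng: random.Random) -> int:
    """Draw the token that replaces a rejected draft token."""
    dist = residual_distribution(p_full, q_draft)
    ids, weights = zip(*dist.items())
    return rng.choices(ids, weights=weights, k=1)[0]
```

Intuitively, the draft model already "covered" the tokens it over-proposes, so the replacement is drawn only from the tokens where the 70B model puts more mass than the 8B model did.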

As Python

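Here’s the same example expressed in Python pseudocode, from the initial prefill through verification and cache rollback. This helped me understand exactly what the inputs and outputs of each part of verification are. For simplicity, it uses greedy decoding: each model picks its most probable token, and a draft token is accepted only if it matches the full model’s most probable token, instead of using the probabilistic acceptance rule above.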

import torch
import torch.nn.functional as F
from transformers import DynamicCache

# Helper sketches (assumptions): draft_model and full_model are Hugging Face transformers
# causal LMs, and the KV cache is a transformers DynamicCache, which is updated in place.
@torch.no_grad()
def forward(model, ids, kv=None):
    # ids: [1, n_new]; returns logits [1, n_new, vocab_size] and the updated cache
    kv = kv if kv is not None else DynamicCache()
    out = model(input_ids=ids, past_key_values=kv, use_cache=True)
    return out.logits, out.past_key_values

def truncate_kv(kv, length):
    # Keep only the first `length` cached positions
    kv.crop(length)
    return kv

# Also assumed to exist: input_ids ([1, prompt_len]), max_new_tokens, K (draft tokens per round)
generated   = []
context_len = input_ids.shape[1]  # prompt tokens + accepted tokens so far

# Step 1: prefill both models with the prompt
draft_logits, draft_kv = forward(draft_model, input_ids)
full_logits,  full_kv  = forward(full_model,  input_ids)

while len(generated) < max_new_tokens:

    # Step 2: draft K tokens, one decode step at a time
    draft_tokens = []
    draft_probs  = []  # needed for the rejection-sampling acceptance criterion
    cur_logits = draft_logits[:, -1, :]  # [1, vocab_size]: last row of draft_logits ([1, seq_len, vocab_size])

    for _ in range(K):
        probs = F.softmax(cur_logits, dim=-1)       # [1, vocab_size]
        # for simplicity, take the most probable token instead of sampling
        token = probs.argmax(dim=-1, keepdim=True)  # [1, 1]
        draft_tokens.append(token.item())
        draft_probs.append(probs[0])                # [vocab_size]
        cur_logits, draft_kv = forward(draft_model, token, draft_kv)
        cur_logits = cur_logits[:, -1, :]           # [1, 1, vocab_size] -> [1, vocab_size]

    # Step 3: full model verifies all K draft tokens in one forward pass
    draft_tensor = torch.tensor([draft_tokens], device=input_ids.device)        # [1, K]
    verify_logits, full_kv_verify = forward(full_model, draft_tensor, full_kv)  # [1, K, vocab_size]

    # Full model's distribution at each draft position:
    # draft token 0 is scored by the last row of full_logits ...
    full_probs = [F.softmax(full_logits[:, -1, :], dim=-1)[0]]
    # ... and draft tokens 1..K-1 by rows 0..K-2 of verify_logits
    for i in range(K - 1):
        full_probs.append(F.softmax(verify_logits[:, i, :], dim=-1)[0])
    # the last row predicts the token after the final draft token
    bonus_probs = F.softmax(verify_logits[:, -1, :], dim=-1)[0]  # [vocab_size]

    # Step 4: accept or reject each draft token
    accepted = []
    n_draft_accepted = 0
    all_accepted = True

    for draft_tok, p_draft, p_full in zip(draft_tokens, draft_probs, full_probs):
        # greedy check for simplicity: accept only if the draft matches the full model's top token
        # (p_draft would be used by the probabilistic acceptance rule)
        if draft_tok == p_full.argmax().item():
            accepted.append(draft_tok)
            n_draft_accepted += 1
        else:
            accepted.append(p_full.argmax().item())  # corrected token
            all_accepted = False
            break

    if all_accepted:
        accepted.append(bonus_probs.argmax().item())  # bonus token from the verify pass

    generated.extend(accepted)
    context_len += len(accepted)

    if len(generated) >= max_new_tokens:
        break

    # Roll back both caches to the accepted draft tokens, discarding the rejected tail.
    # The last accepted token (corrected or bonus) isn't cached yet, so keep
    # (old context_len + n_draft_accepted) == context_len - 1 positions, then feed it in.
    last_token = torch.tensor([[accepted[-1]]], device=input_ids.device)
    keep = context_len - len(accepted) + n_draft_accepted

    full_kv = truncate_kv(full_kv_verify, keep)
    full_logits, full_kv = forward(full_model, last_token, full_kv)

    # repeat for the draft model
    draft_kv = truncate_kv(draft_kv, keep)
    draft_logits, draft_kv = forward(draft_model, last_token, draft_kv)

I also made this into a working script that lets you try speculative decode with different draft and full models to see the potential speedups.
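Because the pseudocode uses greedy acceptance, its output should be token-for-token identical to plain greedy decoding with the full model; speculative decode only changes how many full-model passes it takes. Here’s a stdlib-only sketch that checks this with two hypothetical toy "models" (`toy_full` and `toy_draft` are made-up deterministic functions, not real LLMs).

```python
from typing import Callable, List, Tuple

NextToken = Callable[[List[int]], int]  # greedy "model": sequence -> next token id


def greedy_decode(full: NextToken, prompt: List[int], n: int) -> List[int]:
    """Plain greedy decoding with the full model, one token per pass."""
    seq = list(prompt)
    for _ in range(n):
        seq.append(full(seq))
    return seq[len(prompt):]


def speculative_greedy(full: NextToken, draft: NextToken, prompt: List[int],
                       n: int, k: int) -> Tuple[List[int], int]:
    """Greedy speculative decode; returns (new tokens, number of verify rounds)."""
    seq = list(prompt)
    rounds = 0
    while len(seq) - len(prompt) < n:
        rounds += 1
        drafts: List[int] = []
        for _ in range(k):                  # Step 2: draft k tokens
            drafts.append(draft(seq + drafts))
        # Steps 3-4: a real full model scores all k positions in one forward pass;
        # this toy just calls full() once per position.
        for d in drafts:
            target = full(seq)
            seq.append(target)              # accepted draft, or the full model's correction
            if target != d:
                break
        else:
            seq.append(full(seq))           # all k accepted: bonus token
    return seq[len(prompt):len(prompt) + n], rounds


def toy_full(seq: List[int]) -> int:
    """Hypothetical stand-in for the 70B model: next = (3 * last + 1) mod 7."""
    return (3 * seq[-1] + 1) % 7


def toy_draft(seq: List[int]) -> int:
    """Hypothetical stand-in for the 8B model: agrees with toy_full except after token 4."""
    return 0 if seq[-1] == 4 else toy_full(seq)


print(greedy_decode(toy_full, [0], 10))                     # [1, 4, 6, 5, 2, 0, 1, 4, 6, 5]
print(speculative_greedy(toy_full, toy_draft, [0], 10, 3))  # ([1, 4, 6, 5, 2, 0, 1, 4, 6, 5], 4)
```

Both calls produce the same 10 tokens, but in this toy the speculative version needs only 4 verify rounds instead of 10 full-model passes.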