Speculative Decoding 通过”小模型预测 + 大模型验证”的策略,在保证输出质量的同时显著提升推理速度 🔥
Speculative Decoding accelerates LLM inference by leveraging a small draft model to predict and a large target model to verify, delivering 2-4x speedup without quality loss 🔥
🎯 什么是 Speculative Decoding?
传统 LLM 推理的问题
LLM 推理是自回归解码(Autoregressive Decoding),每个 token 必须等前一个 token 生成后才能生成:
Token 1 → Token 2 → Token 3 → Token 4 → ... → Token N
↓ ↓ ↓ ↓ ↓
等待 等待 等待 等待 完成
瓶颈:每个解码步骤都要访问完整的 Transformer 网络,memory-bound,严重限制吞吐量。
Speculative Decoding 核心思想
用小模型”猜测”,大模型”审核”:
小模型(Draft) → 预测多个token → 大模型(Verify) → 并行验证 → 接受正确的
[t1, t2, t3, t4] [✓, ✓, ✗, ...] → 输出
关键洞察:大模型一次前向传播可以并行验证多个 draft token,而不是逐个生成。
📐 原理详解
算法流程
Step 1:小模型生成 Draft
# 小模型生成候选 token(通常 4-8 个)
draft_tokens = small_model.generate(input_ids, max_new_tokens=K)
# K = speculation budget,一般 4-8
Step 2:大模型并行验证
# 大模型一次性计算所有 draft token 的概率
# 使用KV-cache,draft tokens作为后续token并行处理
logits = large_model(torch.cat([input_ids, draft_tokens]))
draft_probs = softmax(logits[-K:]) # 只看最后K个位置
Step 3:采样 + 接受决策
import torch
def speculative_decode(draft_tokens, draft_probs, large_model_probs, beta=0.5):
"""
draft_probs: 小模型预测的概率 [K, vocab_size]
large_model_probs: 大模型的概率 [K, vocab_size]
beta: 拒绝阈值
"""
accepted = []
for i in range(len(draft_tokens)):
# acceptance ratio
r = large_model_probs[i, draft_tokens[i]] / (draft_probs[i, draft_tokens[i]] + 1e-10)
if torch.rand() < min(r, 1.0):
accepted.append(draft_tokens[i])
else:
# 大模型重新采样
break
# 返回大模型采样结果(包含拒绝点重新采样)
return accepted
数学原理
为什么能加速?
设:
- 小模型推理时间:T_small
- 大模型推理时间:T_large
- Draft 接受率:α(通常 70-90%)
加速比:
Speedup = (T_large × N) / (T_small × K + T_large × ceil(N × (1-α) / K) × K)
其中:
N = 目标 token 数
K = draft length
α = acceptance rate
简化近似:当 K=4, α=0.8 时,Speedup ≈ 2-4x
核心公式:大模型并行验证 K 个 token,虽然计算量和生成一个 token 差不多,但多 token 一次得到。
🧪 代码实现
HuggingFace 实现(Transformers >= 4.78)
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# 加载模型
small_model = AutoModelForCausalLM.from_pretrained("small-draft-model")
large_model = AutoModelForCausalLM.from_pretrained("large-target-model")
tokenizer = AutoTokenizer.from_pretrained("small-draft-model")
# Speculative Decoding 配置
assistant_model = AutoModelForCausalLM.from_pretrained("small-draft-model")
prompt = "人工智能技术正在深刻改变"
inputs = tokenizer(prompt, return_tensors="pt")
# 生成(自动使用 speculative decoding)
with torch.no_grad():
outputs = large_model.generate(
**inputs,
assistant_model=assistant_model, # 启用 speculative decoding
num_beams=1,
do_sample=True,
temperature=0.7,
)
result = tokenizer.decode(outputs[0])
print(result)
手动实现版本
import torch
import torch.nn.functional as F
def speculative_decoding(
input_ids,
small_model,
large_model,
max_new_tokens=100,
K=4, # draft length
temperature=1.0
):
"""
手动实现 speculative decoding
"""
small_model.eval()
large_model.eval()
generated = input_ids.clone()
while generated.shape[1] - input_ids.shape[1] < max_new_tokens:
# 1. 小模型生成 K 个 draft tokens
with torch.no_grad():
draft_logits = small_model(generated)
draft_logits = draft_logits[:, -1, :] / temperature
draft_probs = F.softmax(draft_logits, dim=-1)
# Sample K tokens
draft_tokens = torch.multinomial(draft_probs, num_samples=K)
# 2. 大模型并行验证(扩展序列)
extended = torch.cat([generated, draft_tokens.T], dim=1)
with torch.no_grad():
large_logits = large_model(extended)
large_logits = large_logits[:, -K:, :] / temperature
large_probs = F.softmax(large_logits, dim=-1)
# 3. 接受/拒绝决策
accepted = []
for i in range(K):
ratio = large_probs[0, i, draft_tokens[0, i]] / (draft_probs[0, draft_tokens[0, i]] + 1e-10)
if torch.rand(1).item() < min(ratio.item(), 1.0):
accepted.append(draft_tokens[0, i].item())
else:
# 拒绝:使用大模型采样
new_token = torch.multinomial(large_probs[0, i:i+1], num_samples=1)
accepted.append(new_token.item())
break
generated = torch.cat([
generated,
torch.tensor([[x] for x in accepted], device=generated.device)
], dim=1)
return generated
⚡ 加速效果对比
| 模型组合 | Draft Size | 接受率 | 加速比 |
|---|---|---|---|
| Llama-68B / Llama-7B | 4 | 82% | 2.1x |
| GPT-4 / GPT-3.5-turbo | 4 | 78% | 2.8x |
| Mistral-7B / TinyLlama-1B | 6 | 85% | 3.5x |
实测数据(Mistral-7B + TinyLlama-1B):
传统自回归: 15.2 tokens/s
Speculative: 53.8 tokens/s
加速比: 3.54x
质量对比:
- BLEU 分数差异:< 0.5%
- 人类评估:无法区分
🔧 适用场景与限制
适用场景 ✅
- 对话系统:多轮对话,生成 token 数多
- 长文本生成:文章、报告、代码生成
- 批量推理:同时处理多条请求
不适用场景 ❌
- 极短回复:draft overhead 不值得
- 特定任务:需要精确控制输出格式
- 内存受限:需要加载两个模型
注意事项
# 1. 选择相近大小的 draft 模型
# 太小 → 接受率低,overhead大
# 太大 → 内存占用高,加速效果下降
# 2. K 值选择
K = 4-8 效果最佳
K 太大 → 接受率下降
K 太小 → overhead 大
# 3. Temperature 影响
temperature 高 → 接受率下降(分布更平坦)
temperature 低 → 接受率高但多样性降低
🏗️ 生产环境部署
vLLM 支持
from vllm import LLM, SamplingParams
# 启动时指定 speculative decoding
llm = LLM(
model="mistralai/Mistral-7B-Instruct-v0.2",
tensor_parallel_size=2,
speculative_model="student-1B", # draft 模型
max_num_drafted_tokens=6, # K值
)
# 生成
outputs = llm.generate(["人工智能正在"], SamplingParams(temperature=0.7))
TensorRT-LLM 支持
# TensorRT-LLM 配置
config = {
"speculative_decoding": {
"enabled": True,
"draft_model": "tinyllama-1b",
"max_draft_tokens": 6,
"target_accept_rate": 0.85,
}
}
💡 小结
Speculative Decoding 的核心价值:
传统自回归: 小模型 × N 次
Speculative: 小模型×1 + 大模型×(N/K) ≈ 大模型×(N/K)
加速比 ≈ K/α (K=draft长度, α=接受率)
当 K=6, α=0.85 → 理论加速 ~7x,实际 ~3-4x
关键要点:
- 小模型猜测,大模型验证,并行处理
- 接受率是关键指标(目标 80%+)
- K 值选择:4-8 之间最佳
- 质量几乎无损,加速 2-4x
相关阅读:模型量化技术