RTP-LLM Speculative Engine 深度解析
基于代码分析的推测解码引擎实现详解
项目概览
RTP-LLM 是阿里基础模型推理团队开发的大语言模型推理加速引擎,广泛应用于阿里巴巴集团内部的多个业务场景(淘宝、天猫、闲鱼、菜鸟等)。其中,speculative_engine 模块实现了推测解码技术,通过"草稿模型预测 + 主模型验证"的方式显著提升推理吞吐量。
推测解码的核心思想
传统的大模型推理每步只能生成 1 个 token,而推测解码的核心思路是:
- 草稿模型快速预测:使用轻量级的草稿模型一次性预测 N 个后续 token
- 主模型并行验证:主模型并行对这 N 个 token 进行验证
- 拒绝采样:根据草稿模型和主模型的概率分布,决定接受多少 token
传统解码: [Input] -> Main Model -> [token1] -> [Input+token1] -> Main Model -> [token2] -> ...
推测解码: [Input] -> Draft Model -> [t1, t2, ..., tN] -> Main Model -> [p(t1), p(t2), ..., p(tN)]
-> Sampler -> [accepted tokens]
模块架构
speculative_engine/
├── SpeculativeEngine.h/cc # 主引擎类
├── SpeculativeScheduler.h/cc # 调度器
├── propose_executor/ # 提案执行器(草稿模型)
│ ├── ProposeExecutor.h/cc # 基类
│ ├── VanillaExecutor.h/cc # Vanilla 实现
│ ├── MTPExecutor.h/cc # MTP 实现
│ ├── EagleExecutor.h/cc # Eagle 实现
│ └── MTPBatchStreamProcessor.h/cc # MTP 批量处理器
├── score_executor/ # 评分执行器(主模型)
│ ├── ScoreExecutor.h/cc
│ └── ScoreBatchStreamProcessor.h/cc
└── speculative_sampler/ # 推测采样器
└── SpeculativeSampler.h/cc
三种推测解码类型
1. Vanilla 推测解码
传统的推测解码方式,草稿模型一次性生成 N 个 token。
// SpeculativeEngine::spStep()
THROW_IF_STATUS_ERROR(propose_executor_->propose(streams)); // 生成 N 个 token
THROW_IF_STATUS_ERROR(score_executor_->score(streams)); // 并行验证 N 个 token
CHECK_AND_RETURN_REF(sampler_output, speculative_sampler_->sample(streams)); // 采样
2. MTP 推测解码
Multi-Token Prediction 级联结构,通过多个轻量级模块逐步预测。
// MTPExecutor::propose()
for (size_t i = 0; i < propose_step_; i++) {
if (i > 0) {
mtp_stream->shiftRightOneToken(*stream); // 状态右移
}
mtp_executors_[i]->process(propose_streams); // 每个模块执行一次
}
MTP 模型结构:
// MTPModel::embeddingPost()
auto e_norm = device_->layernorm(LayernormParams(hidden_states, ..., *weights_.layers[0].enorm));
auto h_norm = device_->layernorm(LayernormParams(last_hidden_states, ..., *weights_.layers[0].hnorm));
torch::Tensor cat_tensor = torch::cat({h_norm_tensor, e_norm_tensor}, -1);
auto final_hidden_states = device_->gemm({*cat_buffer, *(weights_.layers[0].eh_proj->kernel)});
3. Eagle 推测解码
与 MTP 类似,但使用同一个模块循环执行。
// EagleExecutor::propose()
for (size_t i = 0; i < propose_step_; i++) {
if (i > 0) {
mtp_stream->shiftRightOneToken(*stream);
}
mtp_executors_[0]->process(propose_streams); // 始终使用第一个模块
}
Eagle3 推测解码(最新变体)
Eagle3 是 Eagle 的增强版本,在融合机制上做了改进。
// Eagle3Model::embeddingPost()
if (last_hidden_states == nullptr) {
// 首次执行,复制 hidden states(Q_proj 需要 2*hidden_size)
BufferPtr duplicate_hidden = device_->clone({*torchTensor2Buffer(
torch::cat({Buffer2torchTensor(hidden_states, false),
Buffer2torchTensor(hidden_states, false)}, 1)),
AllocationType::DEVICE});
return {duplicate_hidden, hidden_states};
}
// 投影上一级隐藏状态
auto proj_last_hidden_states = device_->gemm({*inputs.last_hidden_states,
*(weights_.layers[0].eagle3_fc_proj->kernel)});
// 归一化
auto proj_norm = device_->layernorm(LayernormParams(proj_last_hidden_states, ...,
*weights_.layers[0].eagle3_fc_norm, ...));
auto input_norm = device_->layernorm(LayernormParams(hidden_states, ...,
*weights_.layers[0].eagle3_input_norm, ...));
// 拼接
torch::Tensor cat_tensor = torch::cat({input_norm_tensor, proj_norm_tensor}, -1);
MTP vs Eagle vs Eagle3 对比
执行流程对比
| 特性 | MTP | Eagle | Eagle3 |
|---|---|---|---|
| 执行器数量 | N 个(每个模块独立) | 1 个(循环使用) | 1 个(循环使用) |
| 执行代码 | mtp_executors_[i] | mtp_executors_[0] | mtp_executors_[0] |
| 模块复用 | 每步使用不同模块 | 同一模块重复执行 | 同一模块重复执行 |
模型结构对比
┌─────────────────────────────────────────────────────────────────┐
│ MTP Layer │
├─────────────────────────────────────────────────────────────────┤
│ current embedding ──> LayerNorm(enorm_weights) ──┐ │
│ │ │
│ previous hidden ──> LayerNorm(hnorm_weights) ──┤ │
│ │ │
│ ├─> Concat ─> GEMM(eh_proj) ─> 输出
│ │ │
│ [norm + norm + concat + linear] │ │
└─────────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────┐
│ Eagle3 Layer │
├──────────────────────────────────────────────── ─────────────────┤
│ current embedding ──> LayerNorm(input_norm_weights) ──┐ │
│ │ │
│ previous hidden ──> GEMM(fc_proj) ─> LayerNorm(fc_norm_weights) │
│ │ │
│ ├─> Concat ─> 输出
│ │ │
│ [norm + proj + norm + concat] │ │
└─────────────────────────────────────────────────────────────────┘
关键权重对比
| 权重 | MTP | Eagle3 | 说明 |
|---|---|---|---|
enorm / input_norm | ✓ | ✓ | 当前 embedding 的归一化 |
hnorm | ✓ | - | 上一级隐藏状态的归一化 |
fc_proj | - | ✓ | 上一级隐藏状态的投影 |
fc_norm | - | ✓ | 投影结果的归一化 |
eh_proj | ✓ | - | MTP 融合投影矩阵 |
核心差异
-
Hidden State 处理方式
- MTP: 直接对
last_hidden_states做 LayerNorm - Eagle3: 先通过 GEMM 投影,再做 LayerNorm(增强表达能力)
- MTP: 直接对
-
融合路径
- MTP:
[norm_e, norm_h] → concat → linear - Eagle3:
[norm_input, norm_proj] → concat
- MTP:
-
初始状态
- MTP: 无特殊处理
- Eagle3: 首次执行时复制 hidden states(因为 Q_proj 需要
2 * hidden_size)
设计理念差异
| 方面 | MTP | Eagle3 |
|---|---|---|
| 理念 | 简单直接的双路融合 | 通过投影增强表达 |
| 参数量 | 较少(仅一个 eh_proj) | 略多(多一个 fc_proj) |
| 计算复杂度 | 2 个 LayerNorm + 1 GEMM | 2 个 LayerNorm + 2 GEMM |
| 适用场景 | 资源受限、追求简单 | 追求更高接受率 |
三种类型总结
| 类型 | 执行器数量 | 模型结构 | 特点 |
|---|---|---|---|
| Vanilla | 1 个 | 标准 GPT | 一次性生成 N 个 token |
| MTP | N 个 | MTP (norm+norm+linear) | 每步独立模块,级联预测 |
| Eagle | 1 个 | MTP | 单模块循环使用 |
| Eagle3 | 1 个 | Eagle3 (norm+proj+norm) | 增强融合能力 |
拒绝采样算法
SpeculativeSampler 实现了拒绝采样算法,决定接受多少草稿 token:
// SpeculativeSampler::stochasticSample()
for (int i = 0; i < propose_step; i++) {
float p_draft = propose_all_probs[proposed_token_id];
float p_score = score_all_probs[proposed_token_id];
float accept_prob = p_score / p_draft; // 接受概率
float random = Uniform(0, 1);
if (random <= accept_prob) {
// 接受草稿 token
accepted_token = proposed_token_id;
accept_len++;
} else {
// 拒绝,从修正分布重新采样
new_dist = max(0, p_score - p_draft);
new_dist = new_dist / new_dist.sum();
accepted_token = sample(new_dist);
break; // 停止接受后续 token
}
}
关键点:
- 接受概率
p_score / p_draft:草稿模型和主模型概率越接近,接受率越高 - 当拒绝时,从主模型的剩余分布中采样
- 拒绝后必须停止,确保输出分布与主模型一致
NormalEngine vs SpeculativeEngine 对比
| 对比维度 | NormalEngine | SpeculativeEngine |
|---|---|---|
| 核心思想 | 单模型串行推理 | 草稿模型预测 + 主模型验证 |
| 模型数量 | 1 个主模型 | 1 个草稿模型 + 1 个主模型 |
| 执行流程 | 每步生成 1 个 token | 草稿模型生成 N 个 token,主模型并行验证 |
| Executor | 单个 NormalExecutor | ProposeExecutor + ScoreExecutor |
| 采样 | 直接采样 | 拒绝采样 |
| 内存占用 | 较低 | 较高(需额外草稿模型 + KV Cache) |
| 延迟特点 | 单步延迟低,整体吞吐量较低 | 单步延迟较高,吞吐量更高 |
MTP 级联模型详解
核心概念
MTP (Multi-Token Prediction) 通过级联多个轻量级模块逐步预测后续令牌,与传统"单次生成 N 个 token"的方式不同。
工作流程
Step 1: MTP Module 1
Input: hidden states ──> MTP1 ──> [t1], [h1_next]
Step 2: MTP Module 2
Input: h1_next ──> MTP2 ──> [t2], [h2_next]
... 重复 N 个模块 ...
Score Verification:
[proposed tokens] ──> Main Model ──> [verified probs]
Sample:
[proposed tokens, probs] ──> Rejection Sampling ──> [accepted tokens]
关键组件
1. MTPExecutor
// 每个模块有独立的执行器
std::vector<std::shared_ptr<NormalExecutor>> mtp_executors_;
// 级联执行
for (size_t i = 0; i < propose_step_; i++) {
if (i > 0) {
mtp_stream->shiftRightOneToken(*stream); // 状态右移
}
mtp_executors_[i]->process(propose_streams); // 使用第 i 个模块
}
2. MTPStream
管理级联状态,包括:
current_step_: 当前处于哪个 MTP 模块last_hidden_states_: 上一级输出的隐藏状态mtp_token_index_: 已通过 MTP 处理的 token 索引
3. MTPBatchStreamProcessor
处理批量输入,将所有流的隐藏状态拼接后送入模型:
// 收集所有流的隐藏状态
for (auto& stream : all_streams) {
all_hidden_tokens_num += stream->currentExecuteTokenSize();
}
// 拼接所有流的隐藏状态
device_->multiMergeCopy(params); // 多流合并优化
model_input.last_hidden_states = all_hidden_states;
性能指标
SpeculativeEngine 收集了详细的性能指标:
struct SpeculativeEngineStepMetrics {
int64_t propose_time_us; // 草稿模型推理时间
int64_t score_time_us; // 主模型验证时间
int64_t sampler_time_us; // 采样器时间
int64_t propose_token_num; // 提案 token 数
int64_t accept_token_num; // 接受 token 数
int64_t stream_num; // 处理流数
};
关键指标:Accept Rate
Accept Rate = accept_token_num / propose_token_num
- Accept Rate > 1:草稿模型预测准确,实际接受超过提案数
- Accept Rate ≈ 1:草稿模型预测与主模型一致
- Accept Rate < 1:草稿模型预测偏差较大
使用建议
| 场景 | 推荐引擎 | 原因 |
|---|---|---|
| 实时对话(低延迟优先) | NormalEngine | 单步延迟更低 |
| 批量文本生成(吞吐量优先) | SpeculativeEngine | 通过并行验证提高吞吐 |
| 显存受限环境 | NormalEngine | 无需额外草稿模型 |
| 有优质草稿模型可用 | SpeculativeEngine | 能最大化利用草稿模型加速 |
| Prefill/Decode 分离部署 | SpeculativeEngine (MTP) | 支持 Prefill/Decode 优化 |
总结
RTP-LLM 的推测解码引擎实现了一套完整的推测解码框架:
- 多种算法支持:Vanilla、MTP、Eagle 三种方式
- 灵活调度:支持动态禁用推测解码
- 性能优化:批量处理、多流合并、KV Cache 管理
- 生产就绪:广泛应用于阿里内部多个业务场景
这套实现为理解推测解码技术提供了一个优秀的工程参考。