Skip to main content

RTP-LLM Speculative Engine 深度解析

基于代码分析的推测解码引擎实现详解

项目概览

RTP-LLM 是阿里基础模型推理团队开发的大语言模型推理加速引擎,广泛应用于阿里巴巴集团内部的多个业务场景(淘宝、天猫、闲鱼、菜鸟等)。其中,speculative_engine 模块实现了推测解码技术,通过"草稿模型预测 + 主模型验证"的方式显著提升推理吞吐量。


推测解码的核心思想

传统的大模型推理每步只能生成 1 个 token,而推测解码的核心思路是:

  1. 草稿模型快速预测:使用轻量级的草稿模型一次性预测 N 个后续 token
  2. 主模型并行验证:主模型并行对这 N 个 token 进行验证
  3. 拒绝采样:根据草稿模型和主模型的概率分布,决定接受多少 token
传统解码:    [Input] -> Main Model -> [token1] -> [Input+token1] -> Main Model -> [token2] -> ...
推测解码: [Input] -> Draft Model -> [t1, t2, ..., tN] -> Main Model -> [p(t1), p(t2), ..., p(tN)]
-> Sampler -> [accepted tokens]

模块架构

speculative_engine/
├── SpeculativeEngine.h/cc # 主引擎类
├── SpeculativeScheduler.h/cc # 调度器
├── propose_executor/ # 提案执行器(草稿模型)
│ ├── ProposeExecutor.h/cc # 基类
│ ├── VanillaExecutor.h/cc # Vanilla 实现
│ ├── MTPExecutor.h/cc # MTP 实现
│ ├── EagleExecutor.h/cc # Eagle 实现
│ └── MTPBatchStreamProcessor.h/cc # MTP 批量处理器
├── score_executor/ # 评分执行器(主模型)
│ ├── ScoreExecutor.h/cc
│ └── ScoreBatchStreamProcessor.h/cc
└── speculative_sampler/ # 推测采样器
└── SpeculativeSampler.h/cc

三种推测解码类型

1. Vanilla 推测解码

传统的推测解码方式,草稿模型一次性生成 N 个 token。

// SpeculativeEngine::spStep()
THROW_IF_STATUS_ERROR(propose_executor_->propose(streams)); // 生成 N 个 token
THROW_IF_STATUS_ERROR(score_executor_->score(streams)); // 并行验证 N 个 token
CHECK_AND_RETURN_REF(sampler_output, speculative_sampler_->sample(streams)); // 采样

2. MTP 推测解码

Multi-Token Prediction 级联结构,通过多个轻量级模块逐步预测。

// MTPExecutor::propose()
for (size_t i = 0; i < propose_step_; i++) {
if (i > 0) {
mtp_stream->shiftRightOneToken(*stream); // 状态右移
}
mtp_executors_[i]->process(propose_streams); // 每个模块执行一次
}

MTP 模型结构

// MTPModel::embeddingPost()
auto e_norm = device_->layernorm(LayernormParams(hidden_states, ..., *weights_.layers[0].enorm));
auto h_norm = device_->layernorm(LayernormParams(last_hidden_states, ..., *weights_.layers[0].hnorm));
torch::Tensor cat_tensor = torch::cat({h_norm_tensor, e_norm_tensor}, -1);
auto final_hidden_states = device_->gemm({*cat_buffer, *(weights_.layers[0].eh_proj->kernel)});

3. Eagle 推测解码

与 MTP 类似,但使用同一个模块循环执行。

// EagleExecutor::propose()
for (size_t i = 0; i < propose_step_; i++) {
if (i > 0) {
mtp_stream->shiftRightOneToken(*stream);
}
mtp_executors_[0]->process(propose_streams); // 始终使用第一个模块
}

Eagle3 推测解码(最新变体)

Eagle3 是 Eagle 的增强版本,在融合机制上做了改进。

// Eagle3Model::embeddingPost()
if (last_hidden_states == nullptr) {
// 首次执行,复制 hidden states(Q_proj 需要 2*hidden_size)
BufferPtr duplicate_hidden = device_->clone({*torchTensor2Buffer(
torch::cat({Buffer2torchTensor(hidden_states, false),
Buffer2torchTensor(hidden_states, false)}, 1)),
AllocationType::DEVICE});
return {duplicate_hidden, hidden_states};
}

// 投影上一级隐藏状态
auto proj_last_hidden_states = device_->gemm({*inputs.last_hidden_states,
*(weights_.layers[0].eagle3_fc_proj->kernel)});

// 归一化
auto proj_norm = device_->layernorm(LayernormParams(proj_last_hidden_states, ...,
*weights_.layers[0].eagle3_fc_norm, ...));
auto input_norm = device_->layernorm(LayernormParams(hidden_states, ...,
*weights_.layers[0].eagle3_input_norm, ...));

// 拼接
torch::Tensor cat_tensor = torch::cat({input_norm_tensor, proj_norm_tensor}, -1);

MTP vs Eagle vs Eagle3 对比

执行流程对比

特性MTPEagleEagle3
执行器数量N 个(每个模块独立)1 个(循环使用)1 个(循环使用)
执行代码mtp_executors_[i]mtp_executors_[0]mtp_executors_[0]
模块复用每步使用不同模块同一模块重复执行同一模块重复执行

模型结构对比

┌─────────────────────────────────────────────────────────────────┐
│ MTP Layer │
├─────────────────────────────────────────────────────────────────┤
│ current embedding ──> LayerNorm(enorm_weights) ──┐ │
│ │ │
│ previous hidden ──> LayerNorm(hnorm_weights) ──┤ │
│ │ │
│ ├─> Concat ─> GEMM(eh_proj) ─> 输出
│ │ │
│ [norm + norm + concat + linear] │ │
└─────────────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────────────┐
│ Eagle3 Layer │
├─────────────────────────────────────────────────────────────────┤
│ current embedding ──> LayerNorm(input_norm_weights) ──┐ │
│ │ │
│ previous hidden ──> GEMM(fc_proj) ─> LayerNorm(fc_norm_weights) │
│ │ │
│ ├─> Concat ─> 输出
│ │ │
│ [norm + proj + norm + concat] │ │
└─────────────────────────────────────────────────────────────────┘

关键权重对比

权重MTPEagle3说明
enorm / input_norm当前 embedding 的归一化
hnorm-上一级隐藏状态的归一化
fc_proj-上一级隐藏状态的投影
fc_norm-投影结果的归一化
eh_proj-MTP 融合投影矩阵

核心差异

  1. Hidden State 处理方式

    • MTP: 直接对 last_hidden_states 做 LayerNorm
    • Eagle3: 先通过 GEMM 投影,再做 LayerNorm(增强表达能力)
  2. 融合路径

    • MTP: [norm_e, norm_h] → concat → linear
    • Eagle3: [norm_input, norm_proj] → concat
  3. 初始状态

    • MTP: 无特殊处理
    • Eagle3: 首次执行时复制 hidden states(因为 Q_proj 需要 2 * hidden_size

设计理念差异

方面MTPEagle3
理念简单直接的双路融合通过投影增强表达
参数量较少(仅一个 eh_proj)略多(多一个 fc_proj)
计算复杂度2 个 LayerNorm + 1 GEMM2 个 LayerNorm + 2 GEMM
适用场景资源受限、追求简单追求更高接受率

三种类型总结

类型执行器数量模型结构特点
Vanilla1 个标准 GPT一次性生成 N 个 token
MTPN 个MTP (norm+norm+linear)每步独立模块,级联预测
Eagle1 个MTP单模块循环使用
Eagle31 个Eagle3 (norm+proj+norm)增强融合能力

拒绝采样算法

SpeculativeSampler 实现了拒绝采样算法,决定接受多少草稿 token:

// SpeculativeSampler::stochasticSample()
for (int i = 0; i < propose_step; i++) {
float p_draft = propose_all_probs[proposed_token_id];
float p_score = score_all_probs[proposed_token_id];

float accept_prob = p_score / p_draft; // 接受概率
float random = Uniform(0, 1);

if (random <= accept_prob) {
// 接受草稿 token
accepted_token = proposed_token_id;
accept_len++;
} else {
// 拒绝,从修正分布重新采样
new_dist = max(0, p_score - p_draft);
new_dist = new_dist / new_dist.sum();
accepted_token = sample(new_dist);
break; // 停止接受后续 token
}
}

关键点

  • 接受概率 p_score / p_draft:草稿模型和主模型概率越接近,接受率越高
  • 当拒绝时,从主模型的剩余分布中采样
  • 拒绝后必须停止,确保输出分布与主模型一致

NormalEngine vs SpeculativeEngine 对比

对比维度NormalEngineSpeculativeEngine
核心思想单模型串行推理草稿模型预测 + 主模型验证
模型数量1 个主模型1 个草稿模型 + 1 个主模型
执行流程每步生成 1 个 token草稿模型生成 N 个 token,主模型并行验证
Executor单个 NormalExecutorProposeExecutor + ScoreExecutor
采样直接采样拒绝采样
内存占用较低较高(需额外草稿模型 + KV Cache)
延迟特点单步延迟低,整体吞吐量较低单步延迟较高,吞吐量更高

MTP 级联模型详解

核心概念

MTP (Multi-Token Prediction) 通过级联多个轻量级模块逐步预测后续令牌,与传统"单次生成 N 个 token"的方式不同。

工作流程

Step 1: MTP Module 1
Input: hidden states ──> MTP1 ──> [t1], [h1_next]

Step 2: MTP Module 2
Input: h1_next ──> MTP2 ──> [t2], [h2_next]

... 重复 N 个模块 ...

Score Verification:
[proposed tokens] ──> Main Model ──> [verified probs]

Sample:
[proposed tokens, probs] ──> Rejection Sampling ──> [accepted tokens]

关键组件

1. MTPExecutor

// 每个模块有独立的执行器
std::vector<std::shared_ptr<NormalExecutor>> mtp_executors_;

// 级联执行
for (size_t i = 0; i < propose_step_; i++) {
if (i > 0) {
mtp_stream->shiftRightOneToken(*stream); // 状态右移
}
mtp_executors_[i]->process(propose_streams); // 使用第 i 个模块
}

2. MTPStream

管理级联状态,包括:

  • current_step_: 当前处于哪个 MTP 模块
  • last_hidden_states_: 上一级输出的隐藏状态
  • mtp_token_index_: 已通过 MTP 处理的 token 索引

3. MTPBatchStreamProcessor

处理批量输入,将所有流的隐藏状态拼接后送入模型:

// 收集所有流的隐藏状态
for (auto& stream : all_streams) {
all_hidden_tokens_num += stream->currentExecuteTokenSize();
}

// 拼接所有流的隐藏状态
device_->multiMergeCopy(params); // 多流合并优化

model_input.last_hidden_states = all_hidden_states;

性能指标

SpeculativeEngine 收集了详细的性能指标:

struct SpeculativeEngineStepMetrics {
int64_t propose_time_us; // 草稿模型推理时间
int64_t score_time_us; // 主模型验证时间
int64_t sampler_time_us; // 采样器时间
int64_t propose_token_num; // 提案 token 数
int64_t accept_token_num; // 接受 token 数
int64_t stream_num; // 处理流数
};

关键指标:Accept Rate

Accept Rate = accept_token_num / propose_token_num
  • Accept Rate > 1:草稿模型预测准确,实际接受超过提案数
  • Accept Rate ≈ 1:草稿模型预测与主模型一致
  • Accept Rate < 1:草稿模型预测偏差较大

使用建议

场景推荐引擎原因
实时对话(低延迟优先)NormalEngine单步延迟更低
批量文本生成(吞吐量优先)SpeculativeEngine通过并行验证提高吞吐
显存受限环境NormalEngine无需额外草稿模型
有优质草稿模型可用SpeculativeEngine能最大化利用草稿模型加速
Prefill/Decode 分离部署SpeculativeEngine (MTP)支持 Prefill/Decode 优化

总结

RTP-LLM 的推测解码引擎实现了一套完整的推测解码框架:

  1. 多种算法支持:Vanilla、MTP、Eagle 三种方式
  2. 灵活调度:支持动态禁用推测解码
  3. 性能优化:批量处理、多流合并、KV Cache 管理
  4. 生产就绪:广泛应用于阿里内部多个业务场景

这套实现为理解推测解码技术提供了一个优秀的工程参考。