Skip to main content

sglang pp并行详解

本文以sglang版本:v0.4.9.post2版本中,tp_worker.py的forward_batch_generation函数为切入点,逐步阅读代码,探究sglang中pp流水线并行推理如何实现的?

forward_batch_generation推理过程

  • 根据不同的pp层不同的处理获取输入:非首层获取pp_proxy_tensors输入;第一个 rank 直接从输入 token 开始计算,不需要上游数据
  • model_runner.forward,调用模型推理
  • model_runner.sample,pp最后一层采样获取next_token_ids
  • 返回结果
def forward_batch_generation(
self,
model_worker_batch: ModelWorkerBatch,
launch_done: Optional[threading.Event] = None,
skip_sample: bool = False,
) -> Tuple[
Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor], bool
]:
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)

pp_proxy_tensors = None
# Pipeline Parallelism(流水线并行)中非首 rank 从上游 rank 接收中间计算结果
# 推理前接受pp_proxy_tensors,用于后续的推理
# 只有非首 rank 才执行接收。PP 的第一个 rank 直接从输入 token 开始计算,不需要上游数据;而后续 rank(如第 2、3 阶段)需要接收前一阶段的 hidden states。
if not self.pp_group.is_first_rank:
pp_proxy_tensors = PPProxyTensors(
# self.pp_group.recv_tensor_dict(...) — 从 PP 组的上一个 rank(src = rank - 1)接收一个张量字典。具体流程(见 parallel_state.py):
self.pp_group.recv_tensor_dict(
# 这是一个关键优化:send-allgather 模式。PP 上游 rank 只发送张量的 1/tp_size 切片,接收后在 TP 组内做 all_gather 还原完整张量。这样可以将 PP 跨节点通信量减少到 1/tp_size,显著降低带宽开销
all_gather_group=self.get_attention_tp_group()
)
)

# 如果是最后一个 rank:
# 执行模型前向计算,得到 logits_output
# 设置 launch_done 事件通知调用方
# 如果 skip_sample=False,对 logits 进行采样得到 next_token_ids
# 返回 logits、token ids 和 cuda graph 标志
if self.pp_group.is_last_rank:
logits_output, can_run_cuda_graph = self.model_runner.forward(
forward_batch, pp_proxy_tensors=pp_proxy_tensors
)
if launch_done is not None:
launch_done.set()

if skip_sample:
next_token_ids = None
else:
next_token_ids = self.model_runner.sample(
logits_output, model_worker_batch
)

return logits_output, next_token_ids, can_run_cuda_graph
# 如果是中间 rank:
# 执行模型前向计算,得到中间张量 pp_proxy_tensors
# 返回中间张量(不采样,因为还没到最后一层)
else:
pp_proxy_tensors, can_run_cuda_graph = self.model_runner.forward(
forward_batch,
pp_proxy_tensors=pp_proxy_tensors,
)
return pp_proxy_tensors.tensors, None, can_run_cuda_graph
PP Rank 0 (首 rank)          PP Rank 1 (非首 rank)
────────────────── ─────────────────────
input_ids → 模型前半层 ← recv_tensor_dict (本步骤)
↓ ↓
send_tensor_dict → 中间张量 → 模型后半层

logits → sample → token

recv_tensor_dict过程

self.pp_group.recv_tensor_dict(...) — 从 PP 组的上一个 rank(src = rank - 1)接收一个张量字典。具体流程如下:

  • 先接收元数据(张量的 shape、dtype 等)
  • 再通过 torch.distributed.recv 逐个接收 GPU 张量

采用pickle序列化(pickle)来序列化元数据,确保兼容性。

def recv_tensor_dict(
self,
src: Optional[int] = None,
all_gather_group: Optional["GroupCoordinator"] = None,
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
"""Recv the input tensor dictionary.
NOTE: `src` is the local rank of the source rank.
"""
# Bypass the function if we are using only 1 GPU.
if not torch.distributed.is_initialized() or self.world_size == 1:
return None

all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size
all_gather_rank = (
0 if all_gather_group is None else all_gather_group.rank_in_group
)

group = self.device_group
metadata_group = self.cpu_group

# src 是前一个rank
if src is None:
src = (self.rank_in_group - 1) % self.world_size
assert src < self.world_size, f"Invalid src rank ({src})"
# 先接收元数据(张量的 shape、dtype 等)
recv_metadata_list = self.recv_object(src=src)
tensor_dict: Dict[str, Any] = {}
# 再通过 torch.distributed.recv 逐个接收 GPU 张量
for key, value in recv_metadata_list:
if isinstance(value, TensorMetadata):
tensor = torch.empty(value.size, dtype=value.dtype, device=value.device)
if tensor.numel() == 0:
# Skip broadcasting empty tensors.
tensor_dict[key] = tensor
continue

# send-allgather: send only a slice, then do allgather.
use_all_gather = (
all_gather_group is not None
and tensor.numel() % all_gather_size == 0
)

if use_all_gather:
orig_shape = tensor.shape
tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

if tensor.is_cpu:
# use metadata_group for CPU tensors
torch.distributed.recv(
tensor, src=self.ranks[src], group=metadata_group
)
else:
# use group for GPU tensors
torch.distributed.recv(tensor, src=self.ranks[src], group=group)
if use_all_gather:
# do the allgather
tensor = all_gather_group.all_gather(tensor, dim=0) # type: ignore
tensor = tensor.reshape(orig_shape)

tensor_dict[key] = tensor
else:
tensor_dict[key] = value
return tensor_dict

event_loop_pp调度

如何调用forward_batch_generation?

前面介绍的 recv_tensor_dict 函数是接收张量字典的函数,与之对应的,send_tensor_dict 函数则是发送张量字典的函数。

一起看下发送在哪儿发生的?

事情发生在scheduler.py的event_loop_pp中,当然你可以看到不同种类的调度loop函数。

# 非pd分离模式
if disaggregation_mode == DisaggregationMode.NULL:
# 就是这个,SGLang 调度器在 Pipeline Parallelism(PP)模式下的主事件循环,是非重叠(non-overlap)版本的 PP 调度逻辑。
if server_args.pp_size > 1:
scheduler.event_loop_pp()
elif scheduler.enable_overlap:
scheduler.event_loop_overlap()
else:
scheduler.event_loop_normal()

函数维护 pp_size 个 micro-batch(微批次) 的状态,通过轮询方式处理每个微批次的调度、推理和结果回传。

先看第一个问题:上面的forward_batch_generation函数是怎么循环调度的?也是发生在event_loop_pp函数中的。

# 循环调度
while True:
server_is_idle = True
for mb_id in range(self.pp_size):
...
# 根据id获取当前micro-batch
self.cur_batch = mbs[mb_id]
if self.cur_batch:
server_is_idle = False
# run_batch函数中调用了forward_batch_generation函数
result = self.run_batch(self.cur_batch)
...

核心调度过程

先看一下用到的数据结构:

  • mbs[pp_size] — 当前轮每个微批次的 ScheduleBatch
  • last_mbs[pp_size] — 上一轮各微批次(用于 decode 阶段的状态延续)
  • bids[pp_size] — 各微批次的 batch ID
  • pp_outputs — 上一轮的 PP 输出张量,用于延迟发送

每轮循环的流程(按 mb_id 遍历):

┌─ 1. 接收请求 & 调度下一个 batch
│ recv_requests() → process_input_requests() → get_next_batch_to_run()

├─ 2. 执行推理
│ run_batch(cur_batch) → result

├─ 3.【最后一个 PP rank】组装输出并发送
│ 将 next_token_ids + logprob 等打包为 PPProxyTensors
│ → send_tensor_dict() 发给首 rank

├─ 4. 接收下一个微批次的输出 & 后处理
│ recv_tensor_dict() ← 接收上一轮发出的结果
│ → process_batch_result() 处理完成的请求

└─ 5.【非最后 PP rank】
① send_tensor_dict() PP 环形通信中的接力站
② point_to_point_pyobj() 转发请求给下一个 PP stage
③ send_tensor_dict() 发送 hidden states 给下一个 PP stage

send_tensor_dict过程

发送过程分两阶段:

  • 发送元数据

通过 send_object(CPU group,pickle 序列化)发送元数据。

  • 发送张量数据

遍历 tensor_list,逐个通过 torch.distributed.send 点对点发送,GPU 张量使用 NCCL device group,CPU 张量使用 Gloo CPU group,空张量跳过。

def send_tensor_dict(
self,
tensor_dict: Dict[str, Union[torch.Tensor, Any]],
dst: Optional[int] = None,
all_gather_group: Optional["GroupCoordinator"] = None,
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
"""Send the input tensor dictionary.
NOTE: `dst` is the local rank of the source rank.
"""
# Bypass the function if we are using only 1 GPU.
if not torch.distributed.is_initialized() or self.world_size == 1:
return tensor_dict

all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size
all_gather_rank = (
0 if all_gather_group is None else all_gather_group.rank_in_group
)

group = self.device_group
metadata_group = self.cpu_group

if dst is None:
# dst 是 PP 组内的 下一个 rank,对最后一个 rank 来说,(N-1+1) % N = 0,即回绕到首 rank。
dst = (self.rank_in_group + 1) % self.world_size
assert dst < self.world_size, f"Invalid dst rank ({dst})"

assert isinstance(
tensor_dict, dict
), f"Expecting a dictionary, got {type(tensor_dict)}"
# 调用 _split_tensor_dict 将字典拆分为:
# metadata_list:每个 key 及其张量的 shape/dtype/device(非张量值直接包含)
# tensor_list:实际的张量数据
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
self.send_object(metadata_list, dst=dst)
for tensor in tensor_list:
if tensor.numel() == 0:
# Skip sending empty tensors.
continue

# send-allgather: send only a slice, then do allgather.
if all_gather_group is not None and tensor.numel() % all_gather_size == 0:
tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

if tensor.is_cpu:
# use metadata_group for CPU tensors
torch.distributed.send(
tensor, dst=self.ranks[dst], group=metadata_group
)
else:
# use group for GPU tensors
torch.distributed.send(tensor, dst=self.ranks[dst], group=group)
return None

send-allgather 优化: 当传入 all_gather_group(即 TP 组)时:

if all_gather_group is not None and tensor.numel() % all_gather_size == 0:
tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

发送端只发 1/tp_size 的切片(按当前 TP rank 切分), 接收端收到切片后在 TP 组内做 all_gather 还原完整张量 效果:PP 跨阶段通信量降低为 1/tp_size,例如 TP=4 时通信量减少 75%。

转发请求

非最后 rank 向下一 PP stage 转发数据:

def event_loop_pp(self):
"""A non-overlap scheduler loop for pipeline parallelism."""
......
while True:
server_is_idle = True
for mb_id in range(self.pp_size):
# (not last rank)
if not self.pp_group.is_last_rank:
# PP 环形通信中的接力站——接收前一个 rank 传来的最终输出,再转给下一个 rank,确保每个 PP rank 的调度器都能获得 next_token_ids 来正确管理本地状态。
# 每个 rank 收到后都会执行 process_batch_result()(L920),用 next_token_ids 更新本地调度器状态。然后在下一个迭代中,通过这段代码继续接力转发,直到所有 rank 都拿到输出。
if pp_outputs:
self.pp_group.send_tensor_dict(
pp_outputs.tensors,
all_gather_group=self.attn_tp_group,
)

# 转发请求对象
# 将当前 rank 收到的新请求(recv_reqs,Python 对象列表)转发给下一个 PP stage 的对应 rank
# send out reqs to the next stage
dp_offset = self.attn_dp_rank * self.attn_tp_size
# 只有 attn_tp_rank == 0 才参与——TP 组内只需要一个 rank 做点对点传输,其他 TP rank 后续通过广播获得请求。
if self.attn_tp_rank == 0:
# point_to_point_pyobj 的实现(utils.py:1145):
# 先 pickle.dumps 序列化 Python 对象
# 将字节流包装为 GPU tensor
# 通过 dist.send 发送(两阶段:先发大小,再发数据)
# 接收端反序列化还原
point_to_point_pyobj(
recv_reqs,
self.pp_rank * self.tp_size + dp_offset,
self.world_group.device_group,
self.pp_rank * self.tp_size + dp_offset,
(self.pp_rank + 1) * self.tp_size + dp_offset,
)

# send out proxy tensors to the next stage
# 发送 hidden states
# 将当前 batch 前向计算产生的 中间 hidden states 发送给下一个 PP rank。
# 这就是 PP 前向传播的核心——当前 rank 计算了模型的前几层,产出 hidden states,发给下一个 rank 继续计算后续层。
if self.cur_batch:
self.pp_group.send_tensor_dict(
result.pp_hidden_states_proxy_tensors,
all_gather_group=self.attn_tp_group,
)

为什么需要转发请求? 因为只有 PP rank 0 直接从 tokenizer 接收请求(recv_requests 中只有 pp_rank == 0 从 zmq 读取)。后续 PP stage 的调度器也需要知道请求信息来管理本地状态(KV cache 分配、batch 组织等),所以需要逐级转发。

rank 地址计算(以 PP=2, TP=4, DP=2 为例):

PP0: [R0, R1, R2, R3 | R4, R5, R6, R7]
DP0-TP0..3 DP1-TP0..3
PP1: [R8, R9, R10, R11 | R12, R13, R14, R15]
DP0-TP0..3 DP1-TP0..3

R0 (pp_rank=0, attn_tp_rank=0, dp_offset=0):
src = 0*8 + 0 = 0, dst = 1*8 + 0 = 8 → R0 发给 R8
R4 (pp_rank=0, attn_tp_rank=0, dp_offset=4):
src = 0*8 + 4 = 4, dst = 1*8 + 4 = 12 → R4 发给 R12

环形传递的完整过程图解

以下是 event_loop_pp 的完整调度过程图示(以 PP=3 为例,Rank 0/1/2)。

1. 单个微批次迭代中每个 Rank 的执行流程


2. PP 前向数据流(hidden states 正向传递)


3. next_token_ids 环形分发(核心重点)


4. 环形分发的拓扑总览

环形方向:Last → First → ... → Last,每一跳都跨一个微批次迭代(mb_id),利用微批次的交错时序来隐藏通信延迟。


总结:三条并行数据通路

  • hidden_states:正向逐级传递,驱动 PP 前向计算
  • next_token_ids:环形传递,让所有 rank 的调度器同步最终输出
  • recv_reqs:正向逐级传递,让所有 PP stage 知道要处理哪些请求

环形调度的代码细节

代码里有点绕的地方是 最终输出 next_token_ids(环形回传)。为什么需要这个操作呢?

只有最后一个 rank 拥有最终的 next_token_ids,但 所有 rank 的调度器都需要它 来执行 process_batch_result()(标记完成的请求、更新 KV cache 状态、准备下一轮 decode 的输入等)。因此需要把 next_token_ids 分发给所有 rank。

while True:
for mb_id in range(self.pp_size):
self.running_batch = self.running_mbs[mb_id]
self.last_batch = last_mbs[mb_id]

# 接收请求
recv_reqs = self.recv_requests()
self.process_input_requests(recv_reqs)
mbs[mb_id] = self.get_next_batch_to_run()
self.running_mbs[mb_id] = self.running_batch

self.cur_batch = mbs[mb_id]
if self.cur_batch:
# 调用forward_batch_generation推理过程
result = self.run_batch(self.cur_batch)

# (last rank) send the outputs to the next step
if self.pp_group.is_last_rank:
if self.cur_batch:
next_token_ids, bids[mb_id] = (
result.next_token_ids,
result.bid,
)
if self.cur_batch.return_logprob:
pp_outputs = ...
else:
pp_outputs = PPProxyTensors(
{
"next_token_ids": next_token_ids,
}
)
# send the output from the last round to let the next stage worker run post processing
# 内容:next_token_ids + logprob,目的:将最终输出回传给首 rank
self.pp_group.send_tensor_dict(
pp_outputs.tensors,
all_gather_group=self.attn_tp_group,
)

# receive outputs and post-process (filter finished reqs) the coming microbatch
next_mb_id = (mb_id + 1) % self.pp_size
next_pp_outputs = None
if mbs[next_mb_id] is not None:
next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
self.pp_group.recv_tensor_dict(
all_gather_group=self.attn_tp_group
)
)
mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
logits_output_args = {
k[len("logits_output.") :]: v
for k, v in next_pp_outputs.tensors.items()
if k.startswith("logits_output.")
}
if len(logits_output_args) > 0:
logits_output = LogitsProcessorOutput(**logits_output_args)
else:
logits_output = None
output_result = GenerationBatchResult(
logits_output=logits_output,
pp_hidden_states_proxy_tensors=None,
next_token_ids=next_pp_outputs["next_token_ids"],
extend_input_len_per_req=next_pp_outputs.tensors.get(
"extend_input_len_per_req", None
),
extend_logprob_start_len_per_req=next_pp_outputs.tensors.get(
"extend_logprob_start_len_per_req", None
),
bid=bids[next_mb_id],
can_run_cuda_graph=result.can_run_cuda_graph,
)
self.process_batch_result(mbs[next_mb_id], output_result)
last_mbs[next_mb_id] = mbs[next_mb_id]
if not self.pp_group.is_last_rank:
# 如上节代码所示
# 每个 rank 收到的 next_pp_outputs 会在 下一个微批次迭代 中作为 pp_outputs 被转发出去。因此这个环形传递并不是在一次迭代中完成的,而是跨多个微批次迭代逐步推进,利用了 PP 微批次交错执行的时间差来隐藏通信延迟。
pp_outputs = next_pp_outputs

可以看到在event_loop_pp函数中,调用了三次

代码位置发送者接收者内容目的
L883最后 rank首 rank(环绕)next_token_ids + logprob将最终输出回传给首 rank
L930非最后 rank下一个 rank上一轮收到的 pp_outputs环形接力,让所有 rank 拿到 next_token_ids
L948非最后 rank下一个 rankhidden statesPP 前向计算的中间结果传递

为什么需要这样的环形传递? 根本原因:每个 PP rank 上都运行着独立的调度器实例。 每个 rank 的调度器需要 next_token_ids 来完成以下操作(process_batch_result):

  • 更新请求状态:将生成的 token 追加到每个请求的 output_ids
  • 检查终止条件:判断是否遇到 EOS、达到 max_tokens 等
  • 管理 KV Cache:释放已完成请求的 cache 空间
  • 准备下一轮输入:将 next_token_ids 作为下一轮 decode 的 input_ids
  • 如果某个 rank 不知道 next_token_ids,它的调度器就无法正确管理本地状态,下一轮的 get_next_batch_to_run() 会出错。