Skip to main content

slim中权重更新优化杂记

所有的优化过程参考博客,记录了cudaipc、flatten tensor等训练引擎将权重更新到sglang推理引擎的优化过程。 https://hebiao064.github.io/rl-weight-sync-chinese

sglang支持Flatten Tensor Update Weights

pr:https://github.com/sgl-project/sglang/pull/8079

核心目标

优化 SGLang 中 RL 场景(如 RLHF)的模型权重更新性能,通过将多个小 tensor 打平(flatten)为一个大 tensor 来减少传输和处理的开销。

主要改动

新增 FlattenedTensorBucket 数据结构

  • FlattenedTensorMetadata:dataclass,记录每个 tensor 在打平后的 name、shape、dtype、start/end 索引。
  • FlattenedTensorBucket:核心类,负责:
    • 打平:将多个 (name, tensor) 拼接为一个连续的 1D tensor。
    • 重建:根据 metadata 从打平 tensor 中 slice + reshape 还原出原始 tensor,必要时做 dtype 转换。

关键点:当使用 flattened_bucket 格式时,跳过了 MultiprocessingSerializer 的序列化,直接传递已经打平的数据。

if load_format == "flattened_bucket":
serialized_named_tensors = named_tensors # 直接传递,不走 pickle+base64
else:
serialized_named_tensors = [
MultiprocessingSerializer.serialize(named_tensors)
for _ in range(self.server_args.tp_size)
]

性能优化思路

方面原方案 (逐 tensor)新方案 (flattened bucket)
传输次数多次小 tensor 传输一次大 tensor 传输
序列化pickle + base64 每个 tensor跳过 pickle/base64
内存布局分散连续内存,对 GPU 更友好

为 RL 训练框架(如 VERL)与 SGLang 推理引擎之间的权重同步提供一条更高效的数据通路:把多个 tensor 预先打平为单个连续 tensor + metadata,避免逐 tensor 序列化的开销,减少传输次数。

sglang序列化tensor

这个 MultiprocessingSerializer 类的设计:

核心职责

这是一个序列化/反序列化工具类,专门用于多进程间安全地传递 Python 对象。两个静态方法构成了完整的序列化-反序列化流程。

序列化 (serialize)

使用 ForkingPickler(而非标准 pickle)进行序列化。ForkingPickler 是 multiprocessing 模块提供的,能正确处理跨进程共享的资源(如文件描述符、共享内存等)。 支持两种输出格式: 原始 bytes(默认),适合进程间直接传输。 base64 字符串(output_str=True),适合需要文本编码的场景(如 JSON、HTTP 传输)。 反序列化 (deserialize) 自动检测输入是 str 还是 bytes,如果是字符串则先做 base64 解码。 关键设计:不使用标准 pickle.loads,而是使用自定义的 SafeUnpickler 进行反序列化。

安全机制 (SafeUnpickler)

这是整个设计的重点,采用白名单 + 黑名单双重防护来防止 pickle 反序列化攻击(引用了 CVE-2025-10164):

黑名单优先(DENY_CLASSES):明确拦截 eval、exec、os.system、subprocess.Popen 等高危类/函数。 白名单过滤(ALLOWED_MODULE_PREFIXES):只允许来自受信任模块前缀的类被加载,包括:

  • Python 标准库(builtins、collections 等)
  • PyTorch 相关(torch.*)
  • multiprocessing 相关
  • SGLang 自身模块 默认拒绝:不在白名单中的模块一律拒绝加载。

这是一个典型的安全序列化网关模式:将 pickle 的强大能力限制在可控范围内,既保留了多进程传输复杂对象(如 PyTorch tensor、模型参数)的便利性,又通过 SafeUnpickler 的 find_class 重写堵住了 pickle 反序列化的远程代码执行风险。整体设计以纯静态方法 + 无状态的方式提供服务,使用简洁。