slim中权重更新优化杂记
所有的优化过程参考博客,记录了cudaipc、flatten tensor等训练引擎将权重更新到sglang推理引擎的优化过程。 https://hebiao064.github.io/rl-weight-sync-chinese
sglang支持Flatten Tensor Update Weights
pr:https://github.com/sgl-project/sglang/pull/8079
核心目标
优化 SGLang 中 RL 场景(如 RLHF)的模型权重更新性能,通过将多个小 tensor 打平(flatten)为一个大 tensor 来减少传输和处理的开销。
主要改动
新增 FlattenedTensorBucket 数据结构
FlattenedTensorMetadata:dataclass,记录每个 tensor 在打平后的 name、shape、dtype、start/end 索引。FlattenedTensorBucket:核心类,负责:- 打平:将多个
(name, tensor)拼接为一个连续的 1D tensor。 - 重建:根据 metadata 从打平 tensor 中 slice + reshape 还原出原始 tensor,必要时做 dtype 转换。
- 打平:将多个
关键点:当使用 flattened_bucket 格式时,跳过了 MultiprocessingSerializer 的序列化,直接传递已经打平的数据。
if load_format == "flattened_bucket":
serialized_named_tensors = named_tensors # 直接传递,不走 pickle+base64
else:
serialized_named_tensors = [
MultiprocessingSerializer.serialize(named_tensors)
for _ in range(self.server_args.tp_size)
]
性能优化思路
| 方面 | 原方案 (逐 tensor) | 新方案 (flattened bucket) |
|---|---|---|
| 传输次数 | 多次小 tensor 传输 | 一次大 tensor 传输 |
| 序列化 | pickle + base64 每个 tensor | 跳过 pickle/base64 |
| 内存布局 | 分散 | 连续内存,对 GPU 更友好 |
为 RL 训练框架(如 VERL)与 SGLang 推理引擎之间的权重同步提供一条更高效的数据通路:把多个 tensor 预先打平为单个连续 tensor + metadata,避免逐 tensor 序列化的开销,减少传输次数。
sglang序列化tensor
这个 MultiprocessingSerializer 类的设计:
核心职责
这是一个序列化/反序列化工具类,专门用于多进程间安全地传递 Python 对象。两个静态方法构成了完整的序列化-反序列化流程。
序列化 (serialize)
使用 ForkingPickler(而非标准 pickle)进行序列化。ForkingPickler 是 multiprocessing 模块提供的,能正确处理跨进程共享的资源(如文件描述符、共享内存等)。 支持两种输出格式: 原始 bytes(默认),适合进程间直接传输。 base64 字符串(output_str=True),适合需要文本编码的场景(如 JSON、HTTP 传输)。 反序列化 (deserialize) 自动检测输入是 str 还是 bytes,如果是字符串则先做 base64 解码。 关键设计:不使用标准 pickle.loads,而是使用自定义的 SafeUnpickler 进行反序列化。
安全机制 (SafeUnpickler)
这是整个设计的重点,采用白名单 + 黑名单双重防护来防止 pickle 反序列化攻击(引用了 CVE-2025-10164):
黑名单优先(DENY_CLASSES):明确拦截 eval、exec、os.system、subprocess.Popen 等高危类/函数。 白名单过滤(ALLOWED_MODULE_PREFIXES):只允许来自受信任模块前缀的类被加载,包括:
- Python 标准库(builtins、collections 等)
- PyTorch 相关(torch.*)
- multiprocessing 相关
- SGLang 自身模块 默认拒绝:不在白名单中的模块一律拒绝加载。
这是一个典型的安全序列化网关模式:将 pickle 的强大能力限制在可控范围内,既保留了多进 程传输复杂对象(如 PyTorch tensor、模型参数)的便利性,又通过 SafeUnpickler 的 find_class 重写堵住了 pickle 反序列化的远程代码执行风险。整体设计以纯静态方法 + 无状态的方式提供服务,使用简洁。