Skip to main content

mrope

MRoPE (Multimodal Rotary Position Embedding) 原理与计算方法

1. 解决什么问题

标准 RoPE 为每个 token 分配一个一维递增位置 ID [0, 1, 2, ...],这对纯文本足够。但图像 token 本质是二维网格(高 x 宽),视频还多一个时间维度。一维位置无法表达"同一行的两个 patch 在空间上相邻"这类结构信息。

MRoPE 将位置从 1D 扩展到 3D(T, H, W) — 时间、高度、宽度,每个 token 拥有三个独立的位置 ID。

2. 位置 ID 的计算(get_rope_index

参见 python/minisgl/multimodal/mrope.py:8-110

对于一条包含文本和图像的序列,位置分配规则如下:

文本 token:三个维度获得相同的单调递增值,等价于标准 1D 位置:

T: [0, 1, 2, ...]
H: [0, 1, 2, ...]
W: [0, 1, 2, ...]

图像 token:假设图像被切成 t x h x w 的 patch 网格,每个 patch 获得其在网格中的坐标:

# mrope.py:77-94
t_index = arange(t).expand(h*w) # 时间维坐标
h_index = arange(h).expand(t, w) # 高度坐标
w_index = arange(w).expand(t, h) # 宽度坐标

例如一个 1x4x4 (1帧, 4行, 4列) 的图像,16 个 patch 的位置:

T: [0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0]   # 同一帧
H: [0,0,0,0, 1,1,1,1, 2,2,2,2, 3,3,3,3] # 行号
W: [0,1,2,3, 0,1,2,3, 0,1,2,3, 0,1,2,3] # 列号

最终输出 mrope_positions 的 shape 为 (3, seq_len)

3. 旋转嵌入的应用(MRotaryEmbedding._forward_mrope

参见 python/minisgl/layers/rotary.py:139-183

标准 RoPE 用一个位置 ID 查一组 cos/sin,应用到整个 head_dim。MRoPE 的核心区别是:将 head_dim 分段,每段使用不同维度的位置

mrope_section 定义了分段方式,例如 [16, 24, 24],表示 rotary_dim/2 = 64 维中:

  • 前 16 维用 T (时间) 位置的 cos/sin
  • 中间 24 维用 H (高度) 位置的 cos/sin
  • 最后 24 维用 W (宽度) 位置的 cos/sin

计算步骤:

# 1. 用三组位置分别查 cos/sin 缓存
cos_sin = cache[positions] # (3, seq_len, rotary_dim)
cos, sin = cos_sin.chunk(2, -1) # 各 (3, seq_len, rotary_dim/2)

# 2. 按 mrope_section 分段,每段选对应维度
# split 后 m 的 shape 为 (3, seq_len, section_size)
# m[0] = T 维位置的值, m[1] = H 维, m[2] = W 维
cos = cat([m[i] for i, m in enumerate(cos.split(mrope_section, dim=-1))])
sin = cat([m[i] for i, m in enumerate(sin.split(mrope_section, dim=-1))])
# 结果: (seq_len, rotary_dim/2) — 每段取自不同空间维度

# 3. neox-style 拼倍: (seq_len, rotary_dim)
cos = cat([cos, cos], dim=-1)
sin = cat([sin, sin], dim=-1)

# 4. 标准旋转应用
q_rot = q * cos + rotate_half(q) * sin

4. 直觉理解

维度段位置来源编码的信息
前 16 维T 位置时间/帧顺序(视频帧间距离)
中 24 维H 位置垂直空间距离(同列 patch 的行距)
后 24 维W 位置水平空间距离(同行 patch 的列距)

这样,attention score 的 q·k 点积中:

  • 同一行相邻 patch:W 维位置差小 → 该段旋转角小 → 相似度高
  • 不同帧的 patch:T 维位置差大 → 该段旋转衰减明显
  • 纯文本 token:三维位置相同,退化为标准 1D RoPE

5. Decode 阶段

Decode 时每次只生成一个文本 token,三个维度使用相同位置值(device_len - 1 + position_delta),此时 positions 退化为 1D,走 flashinfer 快速路径(rotary.py:124-133),无额外开销。