Skip to main content

Gemma 4 模型架构解析

本文基于 vLLM-Kunlun 项目中 Gemma 4 模型的适配实践,剖析其架构创新点、适配难点、Paged Attention 机制。


目录

  1. Gemma 4 架构全景
  2. 十大适配挑战与解决方案
  3. 深度解析:PLE 每层嵌入
  4. 深度解析:YOCO KV 共享
  5. 深度解析:Reasoning 推理通道
  6. vLLM Paged Attention 设计原理

1. Gemma 4 架构全景

1.1 与其他模型的对比

特性Llama 3Gemma 2Gemma 3nGemma 4
注意力类型统一 global统一 global混合 sliding/global混合 sliding/global
注意力头维度统一统一统一sliding 和 global 不同
MoE有(可选)
每层嵌入有(硬+软)有(简化 2 层设计)
KV 共享有(YOCO)有(YOCO)
k_eq_v有(laptop 变体)
多模态有(视觉+音频+视频)
RoPE 类型标准标准标准比例式(proportional)

1.2 核心架构图

┌─────────────────────────────────────────────────────────────────┐
│ Gemma4ForConditionalGeneration │
│ │
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────────────┐ │
│ │ Vision │ │ Audio │ │ Language Model │ │
│ │ Tower │ │ Tower │ │ (vLLM Optimized) │ │
│ │ (SigLIP) │ │ (Whisper) │ │ │ │
│ └──────┬───────┘ └──────┬───────┘ │ ┌────────────────┐ │ │
│ │ │ │ │ SelfDecoder │ │ │
│ ┌──────┴─────────────────┴───────┐ │ │ Layers 0..K-1 │ │ │
│ │ Multimodal Embedder │ │ │ (有独立 K/V) │ │ │
│ │ Linear + RMSNorm(no weight) │ │ └───────┬────────┘ │ │
│ └──────────────┬─────────────────┘ │ │ PLE │ │
│ │ │ ┌───────┴────────┐ │ │
│ ▼ │ │ CrossDecoder │ │ │
│ inputs_embeds │ │ Layers K..N-1 │ │ │
│ │ │ (共享前面 K/V) │ │ │
│ │ └────────────────┘ │ │
│ └──────────────────────┘ │
└─────────────────────────────────────────────────────────────────┘

2. 十大适配挑战与解决方案

挑战 1:混合注意力架构(Sliding + Global 并存)

Gemma 4 的 config.layer_types 是一个列表,标记每层是 "sliding_attention" 还是 "full_attention"。不同类型使用不同的参数(头维度、RoPE theta、窗口大小)。

解决方案:每一层根据 layer_type 动态决定参数。

# gemma4.py:325-344
layer_type = config.layer_types[layer_idx]
self.is_sliding = layer_type == "sliding_attention"
sliding_window = config.sliding_window if self.is_sliding else None

# 不同的层类型使用不同的 RoPE 参数
if layer_type in config.rope_parameters:
rope_parameters = dict(config.rope_parameters[layer_type])

挑战 2:不同的头维度(Sliding vs Global)

Global attention 层使用 global_head_dim(通常更大),Sliding 层使用 head_dim。这打破了 vLLM 的"每模型统一 head_dim"假设。

解决方案:每层独立构造 Attention,各自使用正确的 head_dim

# gemma4.py:449-455
if self.is_full_attention:
head_dim = getattr(config, "global_head_dim", config.head_dim)
else:
head_dim = config.head_dim

挑战 3:k_eq_v 变体(无 v_proj)

Laptop 变体中 global attention 层没有单独的 v_proj——K 被复用为 V。

解决方案:权重加载时将 k_proj 克隆到 v_proj 位置。

# gemma4.py:1589-1599
if "self_attn.k_proj" in name and k_eq_v_layer_indices:
yield name, weight
yield name.replace("k_proj", "v_proj"), weight.clone()

挑战 4:MoE 自定义路由

Gemma 4 的 MoE 路由不是标准的分组 top-k softmax,而是:

  1. Router 预处理:RMSNorm(无权重) → × hidden_size⁻⁰·⁵ → × 每维度 learnable scale
  2. Softmax over 所有 experts → top-k 选择 → renormalize
  3. per_expert_scale 折叠到路由权重中以保持 FusedMoE 数值正确性

解决方案:自定义 routing_function 传入 FusedMoE

# gemma4.py:209-231
def routing_function(hidden_states, gating_output, topk, renormalize):
_, topk_ids = torch.topk(gating_output, k=topk, dim=-1)
router_probabilities = torch.nn.functional.softmax(gating_output, dim=-1)
indicator = F.one_hot(topk_ids, num_classes=gating_output.size(-1)).sum(dim=-2)
gate_weights = indicator * router_probabilities
renorm_factor = torch.sum(gate_weights, dim=-1, keepdim=True)
dispatch_weights = gate_weights / renorm_factor
# 折叠 per_expert_scale 到 routing weights
topk_weights = dispatch_weights.gather(1, topk_ids)
topk_weights = topk_weights * per_expert_scale[topk_ids]
return topk_weights, topk_ids

挑战 5:比例式 RoPE(Proportional RoPE)

Gemma 4 的 RoPE 频率计算分母是 head_size 而非 rotary_dim,非旋转维度需要 zero-padding(恒等变换)。

解决方案:自定义 _compute_inv_freq,设置 rotary_dim=head_size 但仅计算实际旋转维度的频率。

# gemma4_rope.py:56-76
freq_exponents = torch.arange(0, 2 * self.rope_angles, 2, dtype=torch.float) / self.head_size
inv_freq = 1.0 / (base ** freq_exponents)
if self.nope_angles > 0:
inv_freq = torch.cat([inv_freq, torch.zeros(self.nope_angles, dtype=torch.float)])

挑战 6:PLE(Per-Layer Embedding)

详见 第 3 节

挑战 7:YOCO KV 共享

详见 第 4 节

挑战 8:Reasoning 通道解析

详见 第 5 节

挑战 9:自定义工具调用语法

Gemma 4 使用非 JSON 格式:<|tool_call>call:func_name{key:<|"|>value<|"|>,num:42}

解决方案:实现完整的状态机解析器 + 流式"累积→解析→差分"策略。

挑战 10:多模态集成

Video 无独立塔——视频帧分解为带时间戳的图像,通过 Vision Tower 以 max_soft_tokens=70 处理。


3. 深度解析:PLE 每层嵌入

3.1 设计动机

标准 Transformer 中,每个 token 只有一个 embedding。PLE 为每一层提供独立的 embedding 贡献,通过门控机制注入。这相当于每一层都能"看到"输入 token 的不同方面。

3.2 数据流五阶段

阶段 1: 主嵌入
embed_tokens(input_ids) × √H → hidden_states (batch, H)

阶段 2: PLE 嵌入查询
embed_tokens_per_layer(input_ids) × √D → reshape → (batch, L, D)
// embed_tokens_per_layer 形状是 [V_ple, L×D]

阶段 3: 投影 + 组合
Linear(hidden_states): H → L×D → ×H⁻⁰·⁵ → reshape(B, L, D) → RMSNorm(D)
(projection + per_layer_embeddings) × 1/√2 → (B, L, D)

阶段 4: 逐层切片
在 _run_decoder_layers 中,每层取 per_layer_inputs[:, layer_idx, :]

阶段 5: 门控注入
gate = GELU(Linear(hidden_states, H→D)) // 门控
contribution = Linear(gate × per_layer_input, D→H) // 投影回 H
contribution = RMSNorm(contribution)
hidden_states = hidden_states + contribution // 残差加法

3.3 关键代码

# 模型级: 投影 + 组合
# gemma4.py:751-777
def project_per_layer_inputs(self, inputs_embeds, per_layer_inputs):
proj = self.per_layer_model_projection(inputs_embeds) # H → L×D
proj = proj * self.per_layer_projection_scale # × H⁻⁰·⁵
proj = proj.reshape(batch, L, D)
proj = self.per_layer_projection_norm(proj) # RMSNorm

# 组合: (投影 + PLE嵌入) / √2
return (proj + per_layer_inputs) * self.per_layer_input_scale

# 层内: 门控注入
# gemma4.py:632-640
gate = self.per_layer_input_gate(hidden_states) # H → D
gate = F.gelu(gate, approximate="tanh")
gated = gate * per_layer_input
contribution = self.per_layer_projection(gated) # D → H
hidden_states = hidden_states + self.post_per_layer_input_norm(contribution)

4. 深度解析:YOCO KV 共享

4.1 架构分割

总层数 N=26, num_kv_shared_layers K=8

层 0 .. 层 17: SelfDecoder — 有独立 K/V 投影和缓存
层 18 .. 层 25: CrossDecoder — 复用前面层的 K/V 缓存

4.2 三层协同机制

KV 共享不是模型层的魔法,而是模型层 + Runner + 注意力后端三层配合的结果:

第 1 层:模型初始化(gemma4.py)

每个共享层的 Gemma4Attention 计算目标层名称并传入 vLLM 的 Attention

# gemma4.py:346-373
kv_sharing_target_layer_name = (
f"{prefix_root}.layers.{kv_shared_layer_index}.self_attn.attn"
)
self.attn = Attention(
...
kv_sharing_target_layer_name=kv_sharing_target_layer_name,
)

第 2 层:Runner 初始化(gpu_model_runner.py)

Runner 遍历所有 Attention 层,对共享层:

  • 跳过 KVCacheSpec 创建(不分配显存)
  • 将其 KV cache 张量别名为目标层的张量
# gpu_model_runner.py:4036-4045
for layer_name, attn_module in attn_layers.items():
if (kv_tgt_layer := attn_module.kv_sharing_target_layer_name) is not None:
self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
continue # ← 不分配 KV cache!

# gpu_model_runner.py:3913-3918
for layer_name, target_layer_name in self.shared_kv_cache_layers.items():
kv_caches[layer_name] = kv_caches[target_layer_name]
# ↑ 同一个 Python tensor 对象!不是拷贝!

第 3 层:注意力后端(flash_attn.py)

共享层跳过 KV cache 写入,直接读目标层已写入的缓存:

# flash_attn.py:488-506
if (self.kv_sharing_target_layer_name is None # ← 共享层条件为 False,跳过!
and key is not None and value is not None):
reshape_and_cache_flash(key, value, key_cache, value_cache, ...)

# flash_attn.py:525-528
flash_attn_varlen_func(
q=query,
k=key_cache, # ← 始终从 key_cache 读(共享层这里就是目标层的缓存)
v=value_cache, # ← 同上
block_table=block_table,
...
)

4.3 共享层前向传播的特殊处理

共享层仍然执行 qkv_proj,但 K 和 V 不参与后续处理:

# gemma4.py:412-424
if not self.is_kv_shared_layer:
k = self.k_norm(k) # 应用 K 归一化
q, k = self.rotary_emb(positions, q, k) # RoPE 作用于 Q 和 K
v = self.v_norm(v) # 应用 V 归一化
else:
q = self.rotary_emb(positions, q, k)[0] # 只对 Q 做 RoPE!
# K 和 V 不变——后续传入 Attention.forward() 但不会被写入缓存

4.4 YOCO 快速预填充优化

# gemma4.py:1085-1168 — fast_prefill_forward
# Step 1: SelfDecoder 处理全部 B 个 token
self_decoder_hidden_states, per_layer_inputs = self.self_decoder(
input_ids=input_ids, # 全部 token
positions=positions,
)

# Step 2: 只取 logit 位置(~1-10 个 token 而非全部 B 个)
num_padded = len(logits_indices_padded)
self.hidden_states[:num_padded].copy_(
self_decoder_hidden_states[logits_indices_padded]
)

# Step 3: CrossDecoder 只处理 logit 位置
cross_hidden_states = self.cross_decoder(
self.positions[:num_padded], # 只有 logit 位置
self.hidden_states[:num_padded],
)

5. 深度解析:Reasoning 推理通道

5.1 Gemma 4 的思维链格式

<|channel>thought           ← 特殊 token 100(开始推理通道)
巴黎是法国的首都,位于欧洲西部...
所以答案是巴黎。
<channel|> ← 特殊 token 101(结束推理通道)
最终的用户可见回答:巴黎是法国的首都。

5.2 核心挑战

vLLM 默认 skip_special_tokens=True删除 <|channel><channel|>,导致推理内容和最终回答混在一起无法分离。

5.3 五层防御体系

第 1 层:adjust_request — 强制 skip_special_tokens=False
第 2 层:猴子补丁 — 兜底:修补 to_sampling_params
第 3 层:分隔符分割 — 主路径:在 <|channel> / <channel|> 上直接分割
第 4 层:标签正则回退 — 降级:匹配 thought*\n 正则
第 5 层:流式 Token 状态机 — 流式路径:基于 token ID 的状态转换

5.4 非流式主路径

# gemma4_reasoning_parser.py:152-196
def extract_reasoning(self, model_output, request):
if "<|channel>" in model_output:
prefix, _, after_start = model_output.partition("<|channel>")
if "<channel|>" in after_start:
reasoning_text, _, content_text = after_start.partition("<channel|>")
reasoning_text = _strip_thought_label(reasoning_text) # 去掉 "thought\n"
return (reasoning_text, prefix + content_text)

5.5 流式 Token-ID 驱动状态机

状态转换:
start_token_id(100) in delta
┌──────────┐ ──────────────────────────→ ┌─────────────┐
│ CONTENT │ │ REASONING │
│ (初始态) │ ←─────────────────────────── │ (思维链) │
└──────────┘ end_token_id(101) in delta └─────────────┘

关键难点:thought\n 标签可能被分成多个 token
"thought" + "en" + "-" + "US" + "\n"

解决方案:缓冲累积 → 正则匹配 → 截断发射
delta "thou" → 缓冲="thou", 无匹配, 不发射
delta "ght\n" → 缓冲="thought\n", 匹配! strip → 不发射
delta "巴黎是" → 已剥离, 发射 reasoning="巴黎是"

6. vLLM Paged Attention 设计原理

6.1 核心思想

借鉴操作系统的虚拟内存分页,将 KV Cache 切分为固定大小的 block(如 64 token/block),通过页表(block_table)将逻辑位置翻译为物理地址。

6.2 四个核心抽象

vLLM 概念OS 类比数据结构
KVCacheBlock物理页帧{block_id, ref_cnt, block_hash}
BlockPool物理内存管理器free_block_queue (LRU 双向链表)
BlockTable页表int32[max_reqs, max_blocks]
slot_mappingMMU 翻译后的物理地址int64[num_tokens]

6.3 关键公式

# block_table.py:76-113 — compute_slot_mapping
slot = block_table[request_idx][token_position // block_size] * block_size
+ token_position % block_size

# 示例: block_size=64, token_position=130
# block_id = block_table[row][130 // 64] = block_table[row][2]
# offset = 130 % 64 = 2
# slot = block_id * 64 + 2

6.4 写入路径(reshape_and_cache)

# flash_attn.py:488-506 + _custom_ops.py:1605-1618
# CUDA kernel 伪代码:
for i in range(num_tokens):
slot = slot_mapping[i]
block_id = slot // BLOCK_SIZE
offset = slot % BLOCK_SIZE
key_cache[block_id, offset, :, :] = key[i]
value_cache[block_id, offset, :, :] = value[i]

6.5 读取路径(Paged Attention)

# flash_attn.py:525-528
flash_attn_varlen_func(
q=query,
k=key_cache, # ← 整个 KV cache 张量
block_table=block_table, # ← FlashAttention 内核通过页表跳转寻址!
)

6.6 Prefix Caching

相同前缀的请求共享同一组物理块:

# block_pool.py:162-186
def get_cached_block(self, block_hash, kv_cache_group_ids):
"""通过 block_hash 查找已缓存的块"""
block = self.cached_block_hash_to_block.get_one_block(key)
# 命中 → touch(block), ref_cnt++
# 未命中 → allocate new block, cache block_hash

好处:多个请求共享 system prompt 时,KV Cache 只计算一次。


参考资料