它是一层调度系统,统一包装 cuDNN、CUTLASS、TRT-LLM、TGV、CuTe DSL、DeepGEMM 和特化 kernel。
1. 学习地图:把 GEMM 当成后端分拣中心
FlashInfer GEMM 的设计不是“我自己实现所有矩阵乘”,而是把不同库和自研 kernel 放在统一 API、JIT、autotune 和架构约束体系里。
mm_bf16 / mm_fp4 / Segment / Group
2. 一次调用的统一调度链路
读 flashinfer/gemm/gemm_base.py 时,抓住这个重复模式就不会迷路:
common shape check → backend requirement → heuristic → workspace → runner list → AutoTuner → backend module。
用户 API
flashinfer.gemm.mm_fp4(...) 或 mm_bf16(...)
参数检查
dtype、shape、layout、SM 架构、cuDNN/CUTLASS 能力。
候选后端
backend="auto" 时生成 suitable backend 列表。
Runner
每个后端包装成 TunableRunner。
AutoTuner
按 shape/profile 选择 runner + tactic,并缓存。
Kernel
调用具体 JIT module、cuDNN graph 或 cubin-backed runner。
mm_fp4(...)
├─ @backend_requirement(...)
├─ _check_mm_fp4_problem_size(...)
├─ _heuristic_func_mm_fp4(...)
├─ workspace = _get_cache_buf(...)
├─ runners = backend_to_runner_factory[backend](...)
└─ AutoTuner.choose_one("fp4_gemm", runners, tuning_config, inputs)
3. API 家族:按场景而不是按函数名记
BF16 dense
mm_bf16 / bmm_bf16 是最适合入门的调度入口。
当前代码事实
mm_bf16 包含 cublaslt 和 tinygemm 后端,dense BF16 的候选范围比只看 CUTLASS/cuDNN 更宽。
bias / PDL
bias 或 PDL 会把候选偏向 TGV/cuDNN/TinyGEMM;CUTLASS 和 cuBLASLt 路径会拒绝。
mm_fp8
当前是 low-latency TRT-LLM 路径,权重 b 需要按低延迟 GEMM 的布局预处理。
mm_mxfp8
核心是 block-scaled FP8。scale 可能是 2D non-swizzled,也可能是 1D swizzled,不能随意 reshape。
mm_fp4
覆盖 NVFP4/MXFP4。block_size=16 对应 NVFP4,block_size=32 对应 MXFP4。
SegmentGEMM
把 ragged segments 转换成 CUTLASS grouped GEMM 的 problem descriptors。
Group GEMM
MoE tokens 按 expert 分组后,每个 expert 一个小 GEMM,适合 groupwise 调度。
DeepGEMM
面向 Blackwell FP8 grouped contiguous / grouped masked MoE GEMM,重点处理 scale layout。
TGV
FlashInfer 自研 Blackwell small-M BF16/FP16 路径,支持 bias / PDL。
CuTe DSL
Python DSL 写 Blackwell block-scaled GEMM,显式组合 TMA、tcgen05、TMEM、pipeline。
RouterGEMM
固定 MoE router 形状专用 kernel,当前源码含 K7168_N128、K7168_N256、K6144_N256。
4. 后端选择矩阵
| 路径 | 适用场景 | 关键约束 |
|---|---|---|
| cuDNN | 通用 dense BF16/FP8/FP4/MXFP8,graph build/cache/execute。 | 依赖 cuDNN frontend/backend 版本;部分路径支持 override shape。 |
| CUTLASS JIT | SM100/103/120 上的 BF16、FP8、FP4、MXFP8 和 groupwise GEMM。 | 通过 Jinja 生成具体 CTA tile 实例;bias/PDL 能力受 wrapper 限制。 |
| TRT-LLM generated | 低精度 generated cubin 和 low-latency GEMM。 | 常要求权重 shuffle 或 128x4 swizzle;不应和普通 layout 后端混用。 |
| TGV | Blackwell small-M BF16/FP16,尤其是 decode dense layer。 | SM100/103,输出 dtype 受限,内部使用 TMA/UMMA/TMEM/barrier pipeline。 |
| CuTe DSL | Blackwell block-scaled FP4/MXFP8,透明 Python kernel。 | scale layout 要求严格;tactic 选择依赖 tile/cluster/PDL 等参数。 |
| DeepGEMM | DeepSeek 风格 FP8 groupwise MoE GEMM。 | 需要 scale factor transform、TMA aligned packed layout、driver launch。 |
| RouterGEMM | M=1..16 的固定 MoE router shape。 | 只服务少数模型形状,换取极低延迟。 |
5. JIT 生成:为什么代码看起来有这么多文件
低精度 GEMM 的组合空间很大:SM 架构、dtype、输出 dtype、CTA tile、scale layout、MMA SM 数、swap_ab 等都会生成不同实例。 FlashInfer 用 Jinja 和 JIT spec 把这些组合变成按需编译的 C++/CUDA 模块。
基础 module
gen_gemm_module() 编译 csrc/bmm_fp8.cu、csrc/group_gemm.cu、csrc/flashinfer_gemm_binding.cu。
CUTLASS 实例化
csrc/fp4_gemm_cutlass.jinja 等模板按 CTA tile 和 dtype 生成很多 .cu。
TRT-LLM cubin
不是源码生成 kernel,而是下载/缓存 generated GEMM 头和 cubin,再编译 runner。
当前代码中的 SM120 FP4 tile 包括 (128,32,128)、(128,32,256)、(128,64,128)、(128,64,256)、(128,128,128)、(128,128,256)、(256,128,128)、(128,256,128)。阅读 JIT 生成逻辑时,要把这些 tile 当作候选实例集合来理解。
6. 低精度 GEMM:重点是 scale 的形状和语义
FP8/FP4/MXFP8 GEMM 的大部分工程复杂度,不在矩阵乘公式,而在“scale tensor 该长什么样、如何 swizzle、哪个 backend 能读这种布局”。
数据格式
- FP8 e4m3/e5m2:常见 activation/weight 低精度。
- NVFP4:
block_size=16。 - MXFP4/MXFP8:
block_size=32或 scale vec size 32。
scale layout
- 2D non-swizzled:看起来像普通二维 scale。
- 1D swizzled:承载
F8_128x4等布局语义。 - TMA aligned packed:DeepGEMM/CuTe DSL 很在意。
风险点
- 不能把 swizzled 1D scale 随手 reshape 成 2D。
- TRT-LLM 后端可能需要 shuffle 后的权重。
- auto 后端不会选择需要不同输入语义的路径。
mm_fp4(a, b, a_descale, b_descale, ...)
block_size = 16 -> NVFP4
block_size = 32 -> MXFP4
mm_mxfp8(a, b, a_descale, b_descale, ...)
use_8x4_sf_layout=True -> swizzled scale path
backend="trtllm" -> requires preprocessed / shuffled weight semantics
backend="cute-dsl" -> requires compatible Blackwell scale layout
7. Segment / Group / DeepGEMM / RouterGEMM:为 MoE 和 ragged serving 服务
SegmentGEMM
输入是拼接后的 ragged tokens,通过 seg_lens/seg_indptr 映射每段,最终转成 grouped GEMM descriptors。
Group GEMM
MoE routing 后每个 expert token 数不同,group GEMM 一次提交多组 expert GEMM。
DeepGEMM
提供 grouped contiguous 和 grouped masked 两种 FP8 NT groupwise 路径。
RouterGEMM
专门优化 MoE router 的小 M 固定形状,当前支持 DeepSeek/Mistral/GLM 类形状。
SegmentGEMM:
x[sum(seg_lens), d_in] + weights[num_weights, d_out, d_in]
-> build problem arrays
-> cutlass_segment_gemm / cutlass_segment_gemm_sm90
DeepGEMM:
GroupedContiguous: packed tokens + m_indices
GroupedMasked: fixed expected_m + masked_m
8. Blackwell 特化:TGV 和 CuTe DSL 怎么理解
TGV GEMM
TGV 是 FlashInfer 自控的 Blackwell GEMM pipeline。它把角色拆成 TMA load、UMMA compute、TMEM accumulator、 epilogue 和 mbarrier 同步,所以可以更自然地支持 bias / PDL / small-M latency。
CuTe DSL GEMM
CuTe DSL 把 Blackwell kernel 结构提升到 Python:你能在 Python 文件中看到 tcgen05、TMA、cluster、 pipeline、scale factor tensor 和 TMEM copy 的组合方式。
9. 推荐源码阅读顺序
mm_bf16:先理解 backend requirement / heuristic / AutoTuner 的最完整入口。
mm_fp4:看低精度 GEMM 的 backend 分支、scale layout、auto 策略。
Jinja 生成 CUTLASS FP4 实例,理解 dtype/tile/SM 如何变成源码。
SegmentGEMMWrapper:理解 ragged segment 如何转换成 grouped GEMM problem arrays。
DeepGEMM scale factor TMA-aligned packed layout,是低精度 group GEMM 的关键。
RouterGEMM:看固定模型形状如何做极致特化。
10. 学习检查清单
能讲清楚
- 为什么 FlashInfer GEMM 是“多后端调度层”,不是一个 kernel。
backend_requirement和 heuristic 的区别。- 为什么 TRT-LLM/CuTe DSL 这类 layout 特殊后端不能随意进入 auto。
- FP4/MXFP8 中 scale layout 为什么比普通 GEMM 更重要。
- SegmentGEMM 和 GroupGEMM 分别解决哪类 serving 问题。
能定位代码
- API 调度:
flashinfer/gemm/gemm_base.py - JIT 生成:
flashinfer/jit/gemm/core.py - Blackwell DSL:
flashinfer/gemm/kernels/dense_blockscaled_gemm_sm100.py - DeepGEMM:
flashinfer/deep_gemm.py - Router 特化:
flashinfer/gemm/routergemm.py