FlashInfer GEMM

Learning Notes

FlashInfer GEMM 子系统:多精度矩阵乘、JIT 生成与多后端调度

基于 Git 仓库 flashinfer-ai/flashinfer 的代码整理;源码快照 commit:fc12ef21d9eddcd9a7b1a4a8024fb3e81fda804d

这份学习页完全按当前代码组织内容。重点不是罗列所有 kernel,而是帮你建立“一个 GEMM API 为什么会走到不同后端”的学习模型。

GEMM 子系统不是单个 kernel

它是一层调度系统,统一包装 cuDNN、CUTLASS、TRT-LLM、TGV、CuTe DSL、DeepGEMM 和特化 kernel。

低精度复杂度在 scale layout

FP8/FP4/MXFP8 的难点常常不是 matmul,而是 scale granularity、swizzle、TMA alignment。

auto 是保守选择

需要特殊权重预处理的后端不会随便进入 auto,避免相同 tensor 被不同 layout 语义误读。

serving 需要 small-M 专用路径

TGV、TinyGEMM、TRT-LLM low latency、RouterGEMM 都是为了压 decode/MoE 热路径延迟。

1. 学习地图:把 GEMM 当成后端分拣中心

FlashInfer GEMM 的设计不是“我自己实现所有矩阵乘”,而是把不同库和自研 kernel 放在统一 API、JIT、autotune 和架构约束体系里。

Python API
mm_bf16 / mm_fp4 / Segment / Group
Requirement + Heuristic
先判断能不能跑,再决定 auto 候选
AutoTuner + Runner
为候选后端和 tactic 做 profile/cache
JIT / Binding
生成 CUTLASS 实例、加载 cubin、注册 custom op
Kernel Backend
cuDNN / CUTLASS / TRT-LLM / TGV / CuTe DSL / DeepGEMM

2. 一次调用的统一调度链路

flashinfer/gemm/gemm_base.py 时,抓住这个重复模式就不会迷路: common shape check → backend requirement → heuristic → workspace → runner list → AutoTuner → backend module。

1

用户 API

flashinfer.gemm.mm_fp4(...)mm_bf16(...)

2

参数检查

dtype、shape、layout、SM 架构、cuDNN/CUTLASS 能力。

3

候选后端

backend="auto" 时生成 suitable backend 列表。

4

Runner

每个后端包装成 TunableRunner

5

AutoTuner

按 shape/profile 选择 runner + tactic,并缓存。

6

Kernel

调用具体 JIT module、cuDNN graph 或 cubin-backed runner。

mm_fp4(...)
  ├─ @backend_requirement(...)
  ├─ _check_mm_fp4_problem_size(...)
  ├─ _heuristic_func_mm_fp4(...)
  ├─ workspace = _get_cache_buf(...)
  ├─ runners = backend_to_runner_factory[backend](...)
  └─ AutoTuner.choose_one("fp4_gemm", runners, tuning_config, inputs)

3. API 家族:按场景而不是按函数名记

BF16 dense

mm_bf16 / bmm_bf16 是最适合入门的调度入口。

cuDNN CUTLASS TGV cuBLASLt TinyGEMM

当前代码事实

mm_bf16 包含 cublaslttinygemm 后端,dense BF16 的候选范围比只看 CUTLASS/cuDNN 更宽。

bias / PDL

bias 或 PDL 会把候选偏向 TGV/cuDNN/TinyGEMM;CUTLASS 和 cuBLASLt 路径会拒绝。

mm_fp8

当前是 low-latency TRT-LLM 路径,权重 b 需要按低延迟 GEMM 的布局预处理。

mm_mxfp8

核心是 block-scaled FP8。scale 可能是 2D non-swizzled,也可能是 1D swizzled,不能随意 reshape。

mm_fp4

覆盖 NVFP4/MXFP4。block_size=16 对应 NVFP4,block_size=32 对应 MXFP4。

SegmentGEMM

把 ragged segments 转换成 CUTLASS grouped GEMM 的 problem descriptors。

Group GEMM

MoE tokens 按 expert 分组后,每个 expert 一个小 GEMM,适合 groupwise 调度。

DeepGEMM

面向 Blackwell FP8 grouped contiguous / grouped masked MoE GEMM,重点处理 scale layout。

TGV

FlashInfer 自研 Blackwell small-M BF16/FP16 路径,支持 bias / PDL。

CuTe DSL

Python DSL 写 Blackwell block-scaled GEMM,显式组合 TMA、tcgen05、TMEM、pipeline。

RouterGEMM

固定 MoE router 形状专用 kernel,当前源码含 K7168_N128、K7168_N256、K6144_N256。

4. 后端选择矩阵

路径 适用场景 关键约束
cuDNN 通用 dense BF16/FP8/FP4/MXFP8,graph build/cache/execute。 依赖 cuDNN frontend/backend 版本;部分路径支持 override shape。
CUTLASS JIT SM100/103/120 上的 BF16、FP8、FP4、MXFP8 和 groupwise GEMM。 通过 Jinja 生成具体 CTA tile 实例;bias/PDL 能力受 wrapper 限制。
TRT-LLM generated 低精度 generated cubin 和 low-latency GEMM。 常要求权重 shuffle 或 128x4 swizzle;不应和普通 layout 后端混用。
TGV Blackwell small-M BF16/FP16,尤其是 decode dense layer。 SM100/103,输出 dtype 受限,内部使用 TMA/UMMA/TMEM/barrier pipeline。
CuTe DSL Blackwell block-scaled FP4/MXFP8,透明 Python kernel。 scale layout 要求严格;tactic 选择依赖 tile/cluster/PDL 等参数。
DeepGEMM DeepSeek 风格 FP8 groupwise MoE GEMM。 需要 scale factor transform、TMA aligned packed layout、driver launch。
RouterGEMM M=1..16 的固定 MoE router shape。 只服务少数模型形状,换取极低延迟。

5. JIT 生成:为什么代码看起来有这么多文件

低精度 GEMM 的组合空间很大:SM 架构、dtype、输出 dtype、CTA tile、scale layout、MMA SM 数、swap_ab 等都会生成不同实例。 FlashInfer 用 Jinja 和 JIT spec 把这些组合变成按需编译的 C++/CUDA 模块。

TRT-LLM cubin

不是源码生成 kernel,而是下载/缓存 generated GEMM 头和 cubin,再编译 runner。

SM120 FP4 tile 事实

当前代码中的 SM120 FP4 tile 包括 (128,32,128)(128,32,256)(128,64,128)(128,64,256)(128,128,128)(128,128,256)(256,128,128)(128,256,128)。阅读 JIT 生成逻辑时,要把这些 tile 当作候选实例集合来理解。

6. 低精度 GEMM:重点是 scale 的形状和语义

FP8/FP4/MXFP8 GEMM 的大部分工程复杂度,不在矩阵乘公式,而在“scale tensor 该长什么样、如何 swizzle、哪个 backend 能读这种布局”。

数据格式

  • FP8 e4m3/e5m2:常见 activation/weight 低精度。
  • NVFP4:block_size=16
  • MXFP4/MXFP8:block_size=32 或 scale vec size 32。

scale layout

  • 2D non-swizzled:看起来像普通二维 scale。
  • 1D swizzled:承载 F8_128x4 等布局语义。
  • TMA aligned packed:DeepGEMM/CuTe DSL 很在意。

风险点

  • 不能把 swizzled 1D scale 随手 reshape 成 2D。
  • TRT-LLM 后端可能需要 shuffle 后的权重。
  • auto 后端不会选择需要不同输入语义的路径。
mm_fp4(a, b, a_descale, b_descale, ...)
  block_size = 16 -> NVFP4
  block_size = 32 -> MXFP4

mm_mxfp8(a, b, a_descale, b_descale, ...)
  use_8x4_sf_layout=True  -> swizzled scale path
  backend="trtllm"        -> requires preprocessed / shuffled weight semantics
  backend="cute-dsl"      -> requires compatible Blackwell scale layout

7. Segment / Group / DeepGEMM / RouterGEMM:为 MoE 和 ragged serving 服务

SegmentGEMM

输入是拼接后的 ragged tokens,通过 seg_lens/seg_indptr 映射每段,最终转成 grouped GEMM descriptors。

Group GEMM

MoE routing 后每个 expert token 数不同,group GEMM 一次提交多组 expert GEMM。

DeepGEMM

提供 grouped contiguous 和 grouped masked 两种 FP8 NT groupwise 路径。

RouterGEMM

专门优化 MoE router 的小 M 固定形状,当前支持 DeepSeek/Mistral/GLM 类形状。

SegmentGEMM:
  x[sum(seg_lens), d_in] + weights[num_weights, d_out, d_in]
  -> build problem arrays
  -> cutlass_segment_gemm / cutlass_segment_gemm_sm90

DeepGEMM:
  GroupedContiguous: packed tokens + m_indices
  GroupedMasked: fixed expected_m + masked_m

8. Blackwell 特化:TGV 和 CuTe DSL 怎么理解

TGV GEMM

TGV 是 FlashInfer 自控的 Blackwell GEMM pipeline。它把角色拆成 TMA load、UMMA compute、TMEM accumulator、 epilogue 和 mbarrier 同步,所以可以更自然地支持 bias / PDL / small-M latency。

TMA UMMA TMEM mbarrier PDL

CuTe DSL GEMM

CuTe DSL 把 Blackwell kernel 结构提升到 Python:你能在 Python 文件中看到 tcgen05、TMA、cluster、 pipeline、scale factor tensor 和 TMEM copy 的组合方式。

tcgen05 PipelineTmaUmma scale factor cluster

9. 推荐源码阅读顺序

flashinfer/gemm/gemm_base.py#L485

mm_bf16:先理解 backend requirement / heuristic / AutoTuner 的最完整入口。

flashinfer/gemm/gemm_base.py#L5885

mm_fp4:看低精度 GEMM 的 backend 分支、scale layout、auto 策略。

flashinfer/jit/gemm/core.py#L85

Jinja 生成 CUTLASS FP4 实例,理解 dtype/tile/SM 如何变成源码。

flashinfer/gemm/gemm_base.py#L1736

SegmentGEMMWrapper:理解 ragged segment 如何转换成 grouped GEMM problem arrays。

flashinfer/deep_gemm.py#L132

DeepGEMM scale factor TMA-aligned packed layout,是低精度 group GEMM 的关键。

flashinfer/gemm/routergemm.py#L169

RouterGEMM:看固定模型形状如何做极致特化。

10. 学习检查清单

能讲清楚

  • 为什么 FlashInfer GEMM 是“多后端调度层”,不是一个 kernel。
  • backend_requirement 和 heuristic 的区别。
  • 为什么 TRT-LLM/CuTe DSL 这类 layout 特殊后端不能随意进入 auto。
  • FP4/MXFP8 中 scale layout 为什么比普通 GEMM 更重要。
  • SegmentGEMM 和 GroupGEMM 分别解决哪类 serving 问题。

参考文件