GEMM 进阶理解

Advanced GEMM Notes

矩阵乘法进阶理解:从 naive kernel 到接近 cuBLAS 的优化路径

这份文档基于你给的 GEMM 文章脉络重写:不是再讲一遍矩阵乘法定义, 而是把每一步优化背后的性能动机讲清楚,包括分块、访存复用、bank conflict、double buffering、 CUDA/CUTLASS 层级,以及 Triton 里如何表达同一套思想。

Naive 慢在访存复用差

每个线程算一个 C 元素,K 次乘加都从 global memory 取 A/B。

分块提高算术强度

一个 A/B tile 被多个输出元素复用,单位字节能产生更多 FLOPs。

bank conflict 会吃掉 shared memory 优势

shared memory 快,但访问模式不对会串行化。

双缓冲让搬运和计算重叠

把“load 后 compute”的串行流程改造成流水线。

1. 文章主线:GEMM 优化阶梯

你给的文章重点是 CUDA SGEMM 的逐步优化:先实现正确,再分析瓶颈, 然后逐层把数据搬运、线程映射、shared memory、寄存器和流水线安排好。

1

CPU / naive CUDA

每个 C 元素独立,M×N 个线程,每个线程沿 K 做点积。

2

Shared Memory 分块

C 按 BM×BN 分块,A/B 沿 BK 切片加载到 shared memory。

3

Bank Conflict 修正

改变 shared memory 布局和线程 tile 形状,让一个 warp 的访存更均匀。

4

Double Buffering

使用两份 shared buffer,让下一块加载和当前块计算重叠。

5

cuBLAS / CUTLASS

把 block tile、warp tile、thread tile 做成成熟层级,接近硬件上限。

2. 从公式到性能问题

GEMM 公式很短:C = A × B。但高性能实现关注的不是公式本身, 而是每个数据从哪里来、被复用多少次、最后什么时候写回。

数学视角

C[i,j] = Σ A[i,k] × B[k,j]

每个输出元素是 A 的一行与 B 的一列做内积。

性能视角

快 = 少搬运 + 多复用 + 好并行

GEMM 的优化目标是让一次 global memory 读取服务更多 FFMA / MMA 计算。

Naive CUDA

线程数很多,但每个线程都直接读 global A/B,内存带宽很快成为瓶颈。

Shared Memory 版本

先把 A/B tile 放进 shared memory,让同一 block 内多个线程共享。

高性能版本

继续处理 bank conflict、寄存器 tile、双缓冲、warp-level tile 和 Tensor Core。

3. 算术强度计算器:为什么 BM/BN 越大越香

对一个 C tile 来说,每个 K 分块需要加载 A tile 和 B tile。计算量约为 2 × BM × BN × BK FLOPs,读取量约为 4 × BK × (BM + BN) bytes。 所以算术强度近似为 BM × BN / (2 × (BM + BN)) FLOP/byte。

A/B shared memory 8192 B
算术强度 32.0
线程数估算 256
C 累加寄存器/线程 64
smem_bytes = 4 * BK * (BM + BN)
flops_per_k_tile = 2 * BM * BN * BK
arithmetic_intensity = flops_per_k_tile / smem_bytes
threads_per_block = (BM / TM) * (BN / TN)

4. 分块不是切图,而是安排复用

一次 block tile 计算中,A 的一小片会服务 BN 方向的多个输出,B 的一小片会服务 BM 方向的多个输出。 这就是 shared memory 分块的核心收益。

蓝色表示当前加载的 A 子块,橙色表示当前加载的 B 子块,绿色表示正在累计的 C 子块。 拖动 K 分块可以看到:C tile 不变,A/B tile 沿 K 方向前进。

A[BM,BK]
×
B[BK,BN]
C[BM,BN]

5. Bank Conflict:shared memory 也会堵车

文章里提到 shared memory 通常有 32 个 bank。一个 warp 的 32 个线程如果均匀落到 32 个 bank, 一拍就能完成;如果多个线程访问同一个 bank,就会被拆成多拍。

模拟公式:bank = (lane × stride) mod 32。 stride=1 时基本无冲突;stride=2 会形成 2-way conflict;stride=8 会更糟。

最大冲突度 1-way
被访问 bank 数 32

为什么 A tile 可能冲突

如果 A 在 shared memory 中按行放,但计算阶段要按列取,warp 内线程地址会跨步,容易落到重复 bank。

文章里的修正思路

把 A 的 shared memory 布局改成近似转置的形态,例如 s_a[BK][BM],让计算阶段访问更顺。

6. Double Buffering:把串行改成流水线

单缓冲是“加载 tile → 同步 → 计算 tile → 同步 → 加载下一个 tile”。双缓冲用两套 shared memory: 当前 buffer 用于计算,另一个 buffer 接收下一块数据。

Single Buffer
Load 0
Compute 0
Load 1
Compute 1
Load 2
Compute 2
Double Buffer
Preload 0
Compute 0 + Load 1
Compute 1 + Load 2
Compute 2 + Load 3
Compute 3
Tail

代价

shared memory 使用量翻倍:从 BK×(BM+BN)2×BK×(BM+BN)

收益

LDG 加载不再完全挡住 FFMA 计算,访存和计算可以部分重叠。

关键细节

先预取第一块,主循环从下一块开始,最后还要补一次尾部计算。

7. cuBLAS / CUTLASS 的层级为什么更稳

文章最后提到 cuBLAS/CUTLASS 的思想:把 GEMM 分解为 thread block tile、warp tile、thread tile, 这正好对应 CUDA 编程模型和 GPU 的内存层级。

Global Memory
A/B 大矩阵
Thread Block Tile
搬入 shared memory
Warp Tile
warp 协作处理
Thread Tile
寄存器累加

block tile

决定一个 thread block 输出多大 C 区域,以及 shared memory 中放多大 A/B tile。

warp tile

把 block tile 分给多个 warp,适配 warp-level 指令和 Tensor Core MMA。

thread tile

每个线程持有一小块 accumulator,减少 C 的中间读写。

8. 用 Triton 复述这套思想

Triton 不要求你手动写 shared memory 和 threadIdx,但它的 block program 依然是在做同一件事: 一个 program 负责一个 C tile,沿 K 分块加载 A/B,调用 tl.dot 累加。

Triton 里对应的概念

CUDA/CUTLASS Triton
thread block tile 一个 program 的 BLOCK_M × BLOCK_N
K 方向分块 for k in range(0, K, BLOCK_K)
寄存器 accumulator acc = tl.zeros(..., tl.float32)
Tensor Core / MMA tl.dot(a, b) 在合适 dtype/layout 下映射到底层矩阵指令
边界判断 mask 控制越界 load/store

进阶 Triton 骨架

@triton.jit
def matmul(A, B, C, M, N, K,
           BLOCK_M: tl.constexpr,
           BLOCK_N: tl.constexpr,
           BLOCK_K: tl.constexpr):
    pid = tl.program_id(0)
    pid_m = pid // tl.cdiv(N, BLOCK_N)
    pid_n = pid %  tl.cdiv(N, BLOCK_N)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    acc = tl.zeros((BLOCK_M, BLOCK_N), tl.float32)

    for k0 in range(0, K, BLOCK_K):
        a = tl.load(A + offs_m[:, None] * K + k0 + offs_k[None, :])
        b = tl.load(B + (k0 + offs_k[:, None]) * N + offs_n[None, :])
        acc += tl.dot(a, b)

    tl.store(C + offs_m[:, None] * N + offs_n[None, :], acc)

9. 进阶学习时最容易卡住的点

卡点 正确理解 怎么检查
Occupancy 越高越好吗 不一定。GEMM 常常愿意用更多寄存器换更高复用和算术强度。 看 SM 吞吐、eligible warps、memory stall,而不是只看 occupancy。
shared memory 一定更快吗 理论上快,但 bank conflict 会让访问被拆成多拍。 看 shared load/store replay 或 bank conflict 相关指标。
BM/BN 越大越好吗 算术强度会上升,但 shared memory、寄存器、block 数和调度都会受影响。 用 autotune 或至少扫几组 shape,不要只凭公式。
Double buffering 一定提速吗 当 load latency 能被 compute 覆盖时收益明显,否则可能只是增加资源占用。 对比单缓冲版本,看 memory dependency stall 是否下降。
手写能超过 cuBLAS 吗 特定 shape/融合 epilogue 有机会,通用 SGEMM 很难长期胜过成熟库。 对齐 dtype、layout、warmup、计时方法后再比较。

10. 复习清单

你应该能说清楚

  • 为什么 naive CUDA 只用到很少峰值算力。
  • BM、BN、BK、TM、TN 分别控制什么。
  • 为什么 BM=BN=128, BK=8 这类配置能提高算术强度。
  • 为什么 shared memory 还要关心 bank conflict。
  • double buffering 是怎样把访存和计算重叠的。

下一步可以做

  • 用 Triton 写一个基础 matmul,再加 autotune。
  • 用 Nsight Compute 看 memory throughput、shared replay、SM utilization。
  • 把 epilogue 融合进去,例如 bias、scale、ReLU/GELU。
  • 对比 naive、shared memory、Triton、cuBLAS 的同 shape 性能。
  • 继续学习 Tensor Core MMA 指令和 CUTLASS tile hierarchy。

参考来源

知乎原文页面在当前环境返回 403,因此我阅读了标注原文链接的转载镜像,并按自己的结构重写为学习文档。