每个线程算一个 C 元素,K 次乘加都从 global memory 取 A/B。
1. 文章主线:GEMM 优化阶梯
你给的文章重点是 CUDA SGEMM 的逐步优化:先实现正确,再分析瓶颈, 然后逐层把数据搬运、线程映射、shared memory、寄存器和流水线安排好。
CPU / naive CUDA
每个 C 元素独立,M×N 个线程,每个线程沿 K 做点积。
Shared Memory 分块
C 按 BM×BN 分块,A/B 沿 BK 切片加载到 shared memory。
Bank Conflict 修正
改变 shared memory 布局和线程 tile 形状,让一个 warp 的访存更均匀。
Double Buffering
使用两份 shared buffer,让下一块加载和当前块计算重叠。
cuBLAS / CUTLASS
把 block tile、warp tile、thread tile 做成成熟层级,接近硬件上限。
2. 从公式到性能问题
GEMM 公式很短:C = A × B。但高性能实现关注的不是公式本身,
而是每个数据从哪里来、被复用多少次、最后什么时候写回。
数学视角
每个输出元素是 A 的一行与 B 的一列做内积。
性能视角
GEMM 的优化目标是让一次 global memory 读取服务更多 FFMA / MMA 计算。
Naive CUDA
线程数很多,但每个线程都直接读 global A/B,内存带宽很快成为瓶颈。
Shared Memory 版本
先把 A/B tile 放进 shared memory,让同一 block 内多个线程共享。
高性能版本
继续处理 bank conflict、寄存器 tile、双缓冲、warp-level tile 和 Tensor Core。
3. 算术强度计算器:为什么 BM/BN 越大越香
对一个 C tile 来说,每个 K 分块需要加载 A tile 和 B tile。计算量约为
2 × BM × BN × BK FLOPs,读取量约为 4 × BK × (BM + BN) bytes。
所以算术强度近似为 BM × BN / (2 × (BM + BN)) FLOP/byte。
smem_bytes = 4 * BK * (BM + BN)
flops_per_k_tile = 2 * BM * BN * BK
arithmetic_intensity = flops_per_k_tile / smem_bytes
threads_per_block = (BM / TM) * (BN / TN)
4. 分块不是切图,而是安排复用
一次 block tile 计算中,A 的一小片会服务 BN 方向的多个输出,B 的一小片会服务 BM 方向的多个输出。 这就是 shared memory 分块的核心收益。
蓝色表示当前加载的 A 子块,橙色表示当前加载的 B 子块,绿色表示正在累计的 C 子块。 拖动 K 分块可以看到:C tile 不变,A/B tile 沿 K 方向前进。
5. Bank Conflict:shared memory 也会堵车
文章里提到 shared memory 通常有 32 个 bank。一个 warp 的 32 个线程如果均匀落到 32 个 bank, 一拍就能完成;如果多个线程访问同一个 bank,就会被拆成多拍。
模拟公式:bank = (lane × stride) mod 32。
stride=1 时基本无冲突;stride=2 会形成 2-way conflict;stride=8 会更糟。
为什么 A tile 可能冲突
如果 A 在 shared memory 中按行放,但计算阶段要按列取,warp 内线程地址会跨步,容易落到重复 bank。
文章里的修正思路
把 A 的 shared memory 布局改成近似转置的形态,例如 s_a[BK][BM],让计算阶段访问更顺。
6. Double Buffering:把串行改成流水线
单缓冲是“加载 tile → 同步 → 计算 tile → 同步 → 加载下一个 tile”。双缓冲用两套 shared memory: 当前 buffer 用于计算,另一个 buffer 接收下一块数据。
代价
shared memory 使用量翻倍:从 BK×(BM+BN) 到 2×BK×(BM+BN)。
收益
LDG 加载不再完全挡住 FFMA 计算,访存和计算可以部分重叠。
关键细节
先预取第一块,主循环从下一块开始,最后还要补一次尾部计算。
7. cuBLAS / CUTLASS 的层级为什么更稳
文章最后提到 cuBLAS/CUTLASS 的思想:把 GEMM 分解为 thread block tile、warp tile、thread tile, 这正好对应 CUDA 编程模型和 GPU 的内存层级。
block tile
决定一个 thread block 输出多大 C 区域,以及 shared memory 中放多大 A/B tile。
warp tile
把 block tile 分给多个 warp,适配 warp-level 指令和 Tensor Core MMA。
thread tile
每个线程持有一小块 accumulator,减少 C 的中间读写。
8. 用 Triton 复述这套思想
Triton 不要求你手动写 shared memory 和 threadIdx,但它的 block program 依然是在做同一件事:
一个 program 负责一个 C tile,沿 K 分块加载 A/B,调用 tl.dot 累加。
Triton 里对应的概念
| CUDA/CUTLASS | Triton |
|---|---|
| thread block tile | 一个 program 的 BLOCK_M × BLOCK_N |
| K 方向分块 | for k in range(0, K, BLOCK_K) |
| 寄存器 accumulator | acc = tl.zeros(..., tl.float32) |
| Tensor Core / MMA | tl.dot(a, b) 在合适 dtype/layout 下映射到底层矩阵指令 |
| 边界判断 | mask 控制越界 load/store |
进阶 Triton 骨架
@triton.jit
def matmul(A, B, C, M, N, K,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr):
pid = tl.program_id(0)
pid_m = pid // tl.cdiv(N, BLOCK_N)
pid_n = pid % tl.cdiv(N, BLOCK_N)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
acc = tl.zeros((BLOCK_M, BLOCK_N), tl.float32)
for k0 in range(0, K, BLOCK_K):
a = tl.load(A + offs_m[:, None] * K + k0 + offs_k[None, :])
b = tl.load(B + (k0 + offs_k[:, None]) * N + offs_n[None, :])
acc += tl.dot(a, b)
tl.store(C + offs_m[:, None] * N + offs_n[None, :], acc)
9. 进阶学习时最容易卡住的点
| 卡点 | 正确理解 | 怎么检查 |
|---|---|---|
| Occupancy 越高越好吗 | 不一定。GEMM 常常愿意用更多寄存器换更高复用和算术强度。 | 看 SM 吞吐、eligible warps、memory stall,而不是只看 occupancy。 |
| shared memory 一定更快吗 | 理论上快,但 bank conflict 会让访问被拆成多拍。 | 看 shared load/store replay 或 bank conflict 相关指标。 |
| BM/BN 越大越好吗 | 算术强度会上升,但 shared memory、寄存器、block 数和调度都会受影响。 | 用 autotune 或至少扫几组 shape,不要只凭公式。 |
| Double buffering 一定提速吗 | 当 load latency 能被 compute 覆盖时收益明显,否则可能只是增加资源占用。 | 对比单缓冲版本,看 memory dependency stall 是否下降。 |
| 手写能超过 cuBLAS 吗 | 特定 shape/融合 epilogue 有机会,通用 SGEMM 很难长期胜过成熟库。 | 对齐 dtype、layout、warmup、计时方法后再比较。 |
10. 复习清单
你应该能说清楚
- 为什么 naive CUDA 只用到很少峰值算力。
- BM、BN、BK、TM、TN 分别控制什么。
- 为什么
BM=BN=128, BK=8这类配置能提高算术强度。 - 为什么 shared memory 还要关心 bank conflict。
- double buffering 是怎样把访存和计算重叠的。
下一步可以做
- 用 Triton 写一个基础 matmul,再加 autotune。
- 用 Nsight Compute 看 memory throughput、shared replay、SM utilization。
- 把 epilogue 融合进去,例如 bias、scale、ReLU/GELU。
- 对比 naive、shared memory、Triton、cuBLAS 的同 shape 性能。
- 继续学习 Tensor Core MMA 指令和 CUTLASS tile hierarchy。
参考来源
知乎原文页面在当前环境返回 403,因此我阅读了标注原文链接的转载镜像,并按自己的结构重写为学习文档。