真正贵的是反复从慢内存搬相同的 A/B 数据。
1. 先建立心智模型
矩阵乘法的每个输出元素,都是 A 的一行和 B 的一列做点积。优化的本质是: 让这条行和这条列的数据尽量少从 global memory 取,多在 cache/shared memory/register 中复用。
C[i,j] = Σ A[i,k] × B[k,j]对所有
k ∈ [0, K) 累加。
M 维
输出 C 的行数。通常对应 A 的 row 方向。
N 维
输出 C 的列数。通常对应 B 的 column 方向。
K 维
归约维。每个 C 元素要沿 K 做乘加累积。
2. 乘法过程可视化
调整下面的 i、j 和 k,观察一个输出元素如何由 A 的一行和 B 的一列累加得到。
3. 朴素实现为什么慢
CPU/Python 视角的三重循环
for i in range(M):
for j in range(N):
acc = 0.0
for k in range(K):
acc += A[i, k] * B[k, j]
C[i, j] = acc
这个版本容易理解,但每个 C 元素都重复读取 A 的一行和 B 的一列,内存复用差。
慢点在哪里
- A/B 数据被反复从 global memory 读取。
- B 的列访问在 row-major 布局下不连续,容易不 coalesced。
- 每次只算一个 C 元素,寄存器和并行度利用不充分。
- 没有把 K 方向切块,cache/shared memory 复用很弱。
4. Tiling:把大问题切成小块
高性能 GEMM 的第一原则是切块。一个 thread block 或 Triton program 负责一块
C_tile[BLOCK_M, BLOCK_N],沿 K 分段加载
A_tile[BLOCK_M, BLOCK_K] 和 B_tile[BLOCK_K, BLOCK_N],累加到寄存器里的 accumulator。
蓝色是当前 A tile,橙色是当前 B tile,绿色是正在累加的 C tile。 K tile 每前进一步,C tile 上的 accumulator 多累加一段部分和。
BLOCK_M × BLOCK_N
决定一个 program/block 输出多少 C 元素。越大复用越好,但寄存器压力也越大。
BLOCK_K
决定每次从 K 方向加载多厚的一片。太小搬运次数多,太大 shared/register 压力高。
Accumulator
中间结果通常保存在寄存器里的 FP32 accumulator,最后再写回 C。
5. 内存层级:优化就是减少慢内存访问
GPU 的算力很强,global memory 却相对慢。GEMM 要快,必须把 A/B tile 搬到更快的层级, 并让每次搬运服务尽可能多的乘加。
Coalescing
让相邻线程读相邻地址,减少 memory transaction。
Shared reuse
一个 A tile 被多个 N 方向输出复用,一个 B tile 被多个 M 方向输出复用。
Register tiling
每个线程算多个 C 元素,accumulator 放寄存器。
Double buffering
一边算当前 tile,一边预取下一块 tile,隐藏内存延迟。
6. 常用优化方法总表
| 方法 | 解决的问题 | 关键取舍 |
|---|---|---|
| Tiling / Blocking | 减少 A/B 反复从 global memory 读取。 | tile 太小复用差,太大寄存器/shared memory 压力高。 |
| Shared memory staging | 把 global memory 的慢访问变成片上复用。 | 需要处理 bank conflict、同步和容量限制。 |
| Register tiling | 减少 C 中间结果读写,提高每线程算术强度。 | 寄存器用太多会降低 occupancy。 |
| Vectorized / coalesced loads | 提高内存吞吐,减少零散访存。 | 依赖对齐、连续布局和边界处理。 |
| Loop unrolling | 减少循环控制开销,暴露更多指令级并行。 | 代码尺寸和寄存器压力可能增加。 |
| Tensor Core / MMA | 利用专用矩阵乘加单元获得数量级吞吐提升。 | 需要合适 dtype、layout、tile shape 和累加精度。 |
| Double buffering / pipelining | 用计算隐藏下一块数据加载延迟。 | 实现更复杂,占用更多 shared memory/register。 |
| Epilogue fusion | 把 bias、activation、scale 等后处理合进 GEMM 写回前。 | 减少额外 kernel 和 memory round trip。 |
| Autotuning | 不同 M/N/K、dtype、GPU 上最佳 tile 不同。 | 用搜索换性能,Triton 特别常用。 |
7. CUDA 写法:从 block/thread 映射理解
CUDA 版本通常显式控制 thread block、shared memory、同步和每个线程负责的输出元素。 下面是概念版,不追求完整可编译,重点看数据如何移动。
朴素 CUDA:每线程一个 C 元素
__global__ void matmul_naive(A, B, C, M, N, K) {
int row = blockIdx.y * blockDim.y + threadIdx.y;
int col = blockIdx.x * blockDim.x + threadIdx.x;
float acc = 0.0f;
for (int k = 0; k < K; ++k) {
acc += A[row * K + k] * B[k * N + col];
}
C[row * N + col] = acc;
}
简单,但 A/B 全靠 global memory,复用差。
Shared memory tiled CUDA
for (int kt = 0; kt < K; kt += TILE_K) {
As[ty][tx] = A[row, kt + tx];
Bs[ty][tx] = B[kt + ty, col];
__syncthreads();
for (int kk = 0; kk < TILE_K; ++kk) {
acc += As[ty][kk] * Bs[kk][tx];
}
__syncthreads();
}
C[row, col] = acc;
A/B tile 先进入 shared memory,然后被 block 内多个线程复用。
CUDA 优化时的三问
- 线程访问 global memory 是否 coalesced?B 的访问是否需要转置、重排或向量化?
- shared memory 是否有 bank conflict?是否需要 padding?
- 每个线程算几个 C 元素?寄存器压力和 occupancy 是否平衡?
8. Triton 写法:用 block program 表达 GEMM
Triton 更像是在写“一个 program 负责一个 C tile”。你不用手动写每个线程,
但仍然要选择 BLOCK_M、BLOCK_N、BLOCK_K,
并理解 mask、stride、accumulator 和 autotune。
Triton matmul 骨架
@triton.jit
def matmul_kernel(A, B, C, M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr):
pid = tl.program_id(0)
pid_m = pid // tl.cdiv(N, BLOCK_N)
pid_n = pid % tl.cdiv(N, BLOCK_N)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
acc = tl.zeros((BLOCK_M, BLOCK_N), tl.float32)
for k0 in range(0, K, BLOCK_K):
a = tl.load(A + offs_m[:, None] * stride_am
+ (k0 + offs_k[None, :]) * stride_ak,
mask=(offs_m[:, None] < M) & (k0 + offs_k[None, :] < K))
b = tl.load(B + (k0 + offs_k[:, None]) * stride_bk
+ offs_n[None, :] * stride_bn,
mask=(k0 + offs_k[:, None] < K) & (offs_n[None, :] < N))
acc += tl.dot(a, b)
tl.store(C + offs_m[:, None] * stride_cm
+ offs_n[None, :] * stride_cn,
acc, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
读 Triton kernel 的关键
pid_m/pid_n决定当前 program 负责哪个 C tile。offs_m/offs_n/offs_k构造 block 内的二维地址。tl.load一次加载 A/B tile,边界靠 mask 处理。tl.dot把 tile 乘加到 FP32 accumulator。tl.store最后一次写回 C,可在此融合 bias/activation。
常见配置
BLOCK_M/N 常见 16、32、64、128;BLOCK_K 常见 32、64。
Autotune
同一 kernel 用多组 block size、num_warps、num_stages 搜索最优。
Tensor Core
当 dtype 和 shape 合适时,tl.dot 会走 MMA/Tensor Core 路径。
9. 从慢到快的学习路线
写出朴素三重循环
目标是完全理解 C[i,j] 的点积过程和 M/N/K 三个维度。
实现 tiled 版本
把 C 切成 tile,沿 K 分段累加。先不追求最快,只追求正确理解数据复用。
观察内存访问
检查 global load 是否连续、shared memory 是否复用、C 是否只写一次。
加入寄存器 tiling 和 pipelining
每个线程负责多个输出元素,并用 double buffering 隐藏下一块 tile 的加载。
迁移到 Tensor Core / Triton autotune
让 tl.dot 或 CUDA MMA 指令承担 tile 乘加,用 autotune 找适合当前 shape 的参数。
10. 优化检查清单
性能视角
- 是不是 memory-bound?算术强度是否足够?
- A/B tile 是否被多个 C 元素复用?
- global load 是否 coalesced,是否可以 vectorize?
- shared memory 是否出现 bank conflict?
- 寄存器是否太多导致 occupancy 下降?
正确性视角
- 边界 M/N/K 不是 block size 倍数时 mask 是否正确?
- FP16/BF16 输入是否使用 FP32 accumulator?
- 转置、stride、layout 是否和真实张量一致?
- 融合 bias/activation 后数值是否和 reference 对齐?
- 不同 batch/shape 下 autotune 参数是否仍然安全?