矩阵乘法优化

Visual Notes

矩阵乘法优化:从三重循环到 CUDA / Triton 高性能 kernel

这是一份面向理解的学习文档:先看清 C = A × B 的乘法过程, 再逐步引入 tiling、共享内存、寄存器累加、Tensor Core、Triton block program 和性能检查方法。

核心瓶颈不是乘法本身

真正贵的是反复从慢内存搬相同的 A/B 数据。

Tiling 是第一优化

把大矩阵切块,让一块 A 和一块 B 被多次复用。

累加放在寄存器

每个线程或 warp 维护多个 C 元素,减少读写 C 的次数。

Triton 抽象 block

你写的是 block-level 程序,编译器负责映射到 GPU 指令。

1. 先建立心智模型

矩阵乘法的每个输出元素,都是 A 的一行和 B 的一列做点积。优化的本质是: 让这条行和这条列的数据尽量少从 global memory 取,多在 cache/shared memory/register 中复用。

C[M,N] = A[M,K] × B[K,N]
单个元素:
C[i,j] = Σ A[i,k] × B[k,j]
对所有 k ∈ [0, K) 累加。

M 维

输出 C 的行数。通常对应 A 的 row 方向。

rows of A/C

N 维

输出 C 的列数。通常对应 B 的 column 方向。

cols of B/C

K 维

归约维。每个 C 元素要沿 K 做乘加累积。

reduction

2. 乘法过程可视化

调整下面的 ijk,观察一个输出元素如何由 A 的一行和 B 的一列累加得到。

A 的行/元素 B 的列/元素 C 的输出
A[4,4]
×
B[4,4]
=
C[4,4]

3. 朴素实现为什么慢

CPU/Python 视角的三重循环

for i in range(M):
    for j in range(N):
        acc = 0.0
        for k in range(K):
            acc += A[i, k] * B[k, j]
        C[i, j] = acc

这个版本容易理解,但每个 C 元素都重复读取 A 的一行和 B 的一列,内存复用差。

慢点在哪里

  • A/B 数据被反复从 global memory 读取。
  • B 的列访问在 row-major 布局下不连续,容易不 coalesced。
  • 每次只算一个 C 元素,寄存器和并行度利用不充分。
  • 没有把 K 方向切块,cache/shared memory 复用很弱。

4. Tiling:把大问题切成小块

高性能 GEMM 的第一原则是切块。一个 thread block 或 Triton program 负责一块 C_tile[BLOCK_M, BLOCK_N],沿 K 分段加载 A_tile[BLOCK_M, BLOCK_K]B_tile[BLOCK_K, BLOCK_N],累加到寄存器里的 accumulator。

蓝色是当前 A tile,橙色是当前 B tile,绿色是正在累加的 C tile。 K tile 每前进一步,C tile 上的 accumulator 多累加一段部分和。

A tiles
×
B tiles
C tile

BLOCK_M × BLOCK_N

决定一个 program/block 输出多少 C 元素。越大复用越好,但寄存器压力也越大。

BLOCK_K

决定每次从 K 方向加载多厚的一片。太小搬运次数多,太大 shared/register 压力高。

Accumulator

中间结果通常保存在寄存器里的 FP32 accumulator,最后再写回 C。

5. 内存层级:优化就是减少慢内存访问

GPU 的算力很强,global memory 却相对慢。GEMM 要快,必须把 A/B tile 搬到更快的层级, 并让每次搬运服务尽可能多的乘加。

Global memory
慢 / 大
Shared memory
中 / 片上
Registers
快 / 小

Coalescing

让相邻线程读相邻地址,减少 memory transaction。

Shared reuse

一个 A tile 被多个 N 方向输出复用,一个 B tile 被多个 M 方向输出复用。

Register tiling

每个线程算多个 C 元素,accumulator 放寄存器。

Double buffering

一边算当前 tile,一边预取下一块 tile,隐藏内存延迟。

6. 常用优化方法总表

方法 解决的问题 关键取舍
Tiling / Blocking 减少 A/B 反复从 global memory 读取。 tile 太小复用差,太大寄存器/shared memory 压力高。
Shared memory staging 把 global memory 的慢访问变成片上复用。 需要处理 bank conflict、同步和容量限制。
Register tiling 减少 C 中间结果读写,提高每线程算术强度。 寄存器用太多会降低 occupancy。
Vectorized / coalesced loads 提高内存吞吐,减少零散访存。 依赖对齐、连续布局和边界处理。
Loop unrolling 减少循环控制开销,暴露更多指令级并行。 代码尺寸和寄存器压力可能增加。
Tensor Core / MMA 利用专用矩阵乘加单元获得数量级吞吐提升。 需要合适 dtype、layout、tile shape 和累加精度。
Double buffering / pipelining 用计算隐藏下一块数据加载延迟。 实现更复杂,占用更多 shared memory/register。
Epilogue fusion 把 bias、activation、scale 等后处理合进 GEMM 写回前。 减少额外 kernel 和 memory round trip。
Autotuning 不同 M/N/K、dtype、GPU 上最佳 tile 不同。 用搜索换性能,Triton 特别常用。

7. CUDA 写法:从 block/thread 映射理解

CUDA 版本通常显式控制 thread block、shared memory、同步和每个线程负责的输出元素。 下面是概念版,不追求完整可编译,重点看数据如何移动。

朴素 CUDA:每线程一个 C 元素

__global__ void matmul_naive(A, B, C, M, N, K) {
  int row = blockIdx.y * blockDim.y + threadIdx.y;
  int col = blockIdx.x * blockDim.x + threadIdx.x;

  float acc = 0.0f;
  for (int k = 0; k < K; ++k) {
    acc += A[row * K + k] * B[k * N + col];
  }
  C[row * N + col] = acc;
}

简单,但 A/B 全靠 global memory,复用差。

Shared memory tiled CUDA

for (int kt = 0; kt < K; kt += TILE_K) {
  As[ty][tx] = A[row, kt + tx];
  Bs[ty][tx] = B[kt + ty, col];
  __syncthreads();

  for (int kk = 0; kk < TILE_K; ++kk) {
    acc += As[ty][kk] * Bs[kk][tx];
  }
  __syncthreads();
}
C[row, col] = acc;

A/B tile 先进入 shared memory,然后被 block 内多个线程复用。

CUDA 优化时的三问

  • 线程访问 global memory 是否 coalesced?B 的访问是否需要转置、重排或向量化?
  • shared memory 是否有 bank conflict?是否需要 padding?
  • 每个线程算几个 C 元素?寄存器压力和 occupancy 是否平衡?

8. Triton 写法:用 block program 表达 GEMM

Triton 更像是在写“一个 program 负责一个 C tile”。你不用手动写每个线程, 但仍然要选择 BLOCK_MBLOCK_NBLOCK_K, 并理解 mask、stride、accumulator 和 autotune。

Triton matmul 骨架

@triton.jit
def matmul_kernel(A, B, C, M, N, K,
                  stride_am, stride_ak,
                  stride_bk, stride_bn,
                  stride_cm, stride_cn,
                  BLOCK_M: tl.constexpr,
                  BLOCK_N: tl.constexpr,
                  BLOCK_K: tl.constexpr):
    pid = tl.program_id(0)
    pid_m = pid // tl.cdiv(N, BLOCK_N)
    pid_n = pid %  tl.cdiv(N, BLOCK_N)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    acc = tl.zeros((BLOCK_M, BLOCK_N), tl.float32)
    for k0 in range(0, K, BLOCK_K):
        a = tl.load(A + offs_m[:, None] * stride_am
                      + (k0 + offs_k[None, :]) * stride_ak,
                    mask=(offs_m[:, None] < M) & (k0 + offs_k[None, :] < K))
        b = tl.load(B + (k0 + offs_k[:, None]) * stride_bk
                      + offs_n[None, :] * stride_bn,
                    mask=(k0 + offs_k[:, None] < K) & (offs_n[None, :] < N))
        acc += tl.dot(a, b)

    tl.store(C + offs_m[:, None] * stride_cm
               + offs_n[None, :] * stride_cn,
             acc, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))

读 Triton kernel 的关键

  • pid_m/pid_n 决定当前 program 负责哪个 C tile。
  • offs_m/offs_n/offs_k 构造 block 内的二维地址。
  • tl.load 一次加载 A/B tile,边界靠 mask 处理。
  • tl.dot 把 tile 乘加到 FP32 accumulator。
  • tl.store 最后一次写回 C,可在此融合 bias/activation。

常见配置

BLOCK_M/N 常见 16、32、64、128;BLOCK_K 常见 32、64。

Autotune

同一 kernel 用多组 block size、num_warps、num_stages 搜索最优。

Tensor Core

当 dtype 和 shape 合适时,tl.dot 会走 MMA/Tensor Core 路径。

9. 从慢到快的学习路线

1

写出朴素三重循环

目标是完全理解 C[i,j] 的点积过程和 M/N/K 三个维度。

2

实现 tiled 版本

把 C 切成 tile,沿 K 分段累加。先不追求最快,只追求正确理解数据复用。

3

观察内存访问

检查 global load 是否连续、shared memory 是否复用、C 是否只写一次。

4

加入寄存器 tiling 和 pipelining

每个线程负责多个输出元素,并用 double buffering 隐藏下一块 tile 的加载。

5

迁移到 Tensor Core / Triton autotune

tl.dot 或 CUDA MMA 指令承担 tile 乘加,用 autotune 找适合当前 shape 的参数。

10. 优化检查清单

性能视角

  • 是不是 memory-bound?算术强度是否足够?
  • A/B tile 是否被多个 C 元素复用?
  • global load 是否 coalesced,是否可以 vectorize?
  • shared memory 是否出现 bank conflict?
  • 寄存器是否太多导致 occupancy 下降?

正确性视角

  • 边界 M/N/K 不是 block size 倍数时 mask 是否正确?
  • FP16/BF16 输入是否使用 FP32 accumulator?
  • 转置、stride、layout 是否和真实张量一致?
  • 融合 bias/activation 后数值是否和 reference 对齐?
  • 不同 batch/shape 下 autotune 参数是否仍然安全?

关键词索引

GEMM Tiling Shared Memory Register Tiling Coalesced Load Bank Conflict Tensor Core MMA / WMMA Triton tl.dot Autotune Epilogue Fusion