Full learning note with source calibration

FlashAttention vs SageAttention

先讲标准 Attention 的瓶颈, 再串起 FlashAttention v1/v2/v3/v4 的 IO-aware 路线,最后用 SageAttention 2.2.0 源码校准 INT8/FP8/FP4 量化路线、GPU 后端、功能边界和选型决策。

Note: FlashAttention vs SageAttention Repo: thu-ml/SageAttention Subdir: SageAttention Commit: 3647690 Package: sageattention 2.2.0
IO-aware INT8 QK FP16 / FP8 / FP4
Q / Query
per-block / per-warp / per-thread scale
K / Key
smooth_k fused subtract mean
V / Value
FP16 on SM80/86, FP8 on SM89+

核心结论

两条路线解决同一个 Attention 瓶颈:FlashAttention 优先减少 HBM IO 并保持数学等价;SageAttention 优先降低 QK/PV 计算精度以换取更高 tensor core 吞吐。

IO

FlashAttention 不物化 N×N

FlashAttention 的核心是 tiling + online softmax:在 SRAM 中分块计算 QK、softmax、PV,只把 Q/K/V/O/LSE 等 O(N) 状态留在 HBM, 避免存储 S/P 两个 N×N 矩阵。

8-bit

Q/K 统一走 INT8

主包的所有 SageAttention2/2++ 前向路径都会把 Q/K 量化为 INT8,并携带 FP32 scale 进入 attention kernel。 CUDA 量化入口在 quant.py:22quant.py:105

SM

按 GPU 架构自动分派

sageattn() 读取当前设备 compute capability,SM80/86/89/90/120/121 分别进入不同后端; 分派逻辑集中在 core.py:143

FP8

V 的精度随后端变化

SM80/SM86 使用 FP16 PV,SM89/SM90/SM120/SM121 使用 FP8 V 路径; FP8 V 会先转置、padding、permute,再做 per-channel 量化,入口在 quant.py:224

FP4

SageAttention3 是独立分支

Blackwell FP4 实现在 sageattention3_blackwell,入口是 sageattn3_blackwell(), 不是主包 sageattention.sageattn() 自动调度的一部分。

准确性边界:FlashAttention/FA4 部分来自论文与生态背景,用于复习技术路线;SageAttention 具体行为以本地源码 commit 3647690 为准。凡是“开发中、PR 中、版本号、硬件支持状态”这类会快速变化的信息,生产选型前都需要重新核对上游。

Attention 的瓶颈

标准 Attention 的 FLOPs 是 O(N²d),但工程上更痛的是 N×N 中间矩阵和 HBM 往返。

标准计算流程

S = Q × K^T          # N × N scores
P = softmax(S)      # N × N probabilities
O = P × V           # N × d output

如果直接实现,S 和 P 都会成为 N×N 矩阵。序列越长,显存和带宽压力按平方增长;N=4096 时,单 head 的 FP16 attention matrix 已经是数十 MB 量级。

两类优化路线

路线核心动作代价
FlashAttention 减少 HBM IO,不物化 N×N,online softmax 精确重缩放。 kernel 编排复杂,对 GPU 架构高度适配。
SageAttention Q/K INT8、V FP8/FP16/FP4,利用低比特 tensor core 吞吐。 引入量化误差,需要 smoothing、细粒度 scale 和两级累加控制精度。

FlashAttention 原理

FlashAttention 的关键不是减少理论 FLOPs,而是让 Q/K/V 小块在 SRAM 中完成计算,避免反复读写 N×N 的 score/probability。

IO-Aware Tiling

Q、K、V 被切成适合 SRAM 的 tile。每个 Q tile 依次扫描 K/V tile,在片上完成局部 score、softmax 统计量更新和输出累加。 HBM 只保留输入、输出和每行 LSE 这样的 O(N) 状态。

Online Softmax

softmax 的全局分母通过增量维护行最大值 m 和指数和 l 完成。每处理一个 K block,如果新的最大值变大, 旧输出乘以 exp(m_old - m_new) 精确重缩放。

Backward Recomputation

前向不保存 P,而保存 O 与每行 LSE;反向按块重算 P_ij = exp(Q_i K_j^T / sqrt(d) - L_i)。 这牺牲一部分计算,换显存和 HBM 访问下降。

Online softmax 状态更新:
m_new = max(m_old, rowmax(S_ij))
scale = exp(m_old - m_new)
l_new = l_old * scale + rowsum(exp(S_ij - m_new))
O_new = O_old * scale + exp(S_ij - m_new) @ V_j
O_final = O_new / l_new

FlashAttention 演进

从 FA1 到 FA4,主线是把非矩阵乘开销、同步、数据搬运和硬件异步能力一点点压进更深的流水线。

FA1 · 2022

IO-aware attention

用 tiling + online softmax 避免保存 S/P,证明在合理 SRAM 大小下可以达到更优 HBM IO 复杂度。 这是 FlashAttention 系列的“数学骨架”。

FA2 · 2023

并行度与 warp 分区优化

交换循环顺序,提高长序列下的 SM 利用率;让 warp 持有完整行以减少 warp 间通信;延迟归一化,减少非矩阵乘 FLOPs。

FA3 · 2024

Hopper: WGMMA + TMA + Warp Specialization

使用异步 WGMMA 做矩阵乘、TMA 做 HBM/SRAM 搬运,生产者/消费者 warp 特化。Ping-pong 调度让一个 warpgroup 做 GEMM 时,另一个处理 softmax。

FA4 · 2026 背景

Blackwell: 更深流水线与可编程稀疏

笔记中的 FA4 信息应视为论文/生态背景:面向 Blackwell 的异步 MMA、TMEM、软件 exp、条件重缩放、CuTe DSL 和 FlexAttention 集成。 这些不是本 SageAttention 仓库源码实现。

FA4 复习要点

这部分来自论文梳理,用来理解为什么 B200/SM100 上 FlashAttention 路线仍然很强。

5 路 Warp 特化

将 FA4 拆成 load、MMA、softmax、correction、epilogue 多类 warp。核心意图是把完全异步的 tensor core 工作、softmax 和数据搬运重叠起来, 防止 Blackwell 上 SFU/共享内存等非对称资源成为瓶颈。

软件 exp 与条件重缩放

当硬件指数单元跟不上 tensor core 吞吐时,FA4 用 FMA 上的多项式近似实现部分 exp2, 并只在数值稳定性需要时重缩放中间输出,减少 online softmax 的校正次数。

TMEM 与 2-CTA MMA

Blackwell 的 Tensor Memory 可作为 tensor core 直连累加空间。2-CTA MMA 可以扩大 tile 并减少部分反向路径中的冗余数据流和原子归约压力。

CuTe DSL

强调 FA4 从 CUDA C++ 模板转向 Python 侧 CuTe DSL,降低 kernel 变体开发和编译迭代成本。这一点属于 FA4 生态实现,不是 SageAttention 当前源码。

FlexAttention / Block-Sparse

score_mod 与 mask_mod 可以把 ALiBi、softcapping、sliding window、document mask 等模式在编译时注入 kernel。 对稀疏 attention,关键是跳过 Empty block、简化 Full block、只对 Partial block 应用 mask。

准确性提醒

FA4 版本号、PR、GPU 支持状态和具体 TFLOPS 数字都属于上游快速变化内容。HTML 保留这些作为复习线索, 但不把它们当作本 SageAttention 仓库可证明的源码事实。

架构图

从 Python API 到 CUDA/Triton kernel,关键分层可以拆成 API 层、调度层、量化层、attention kernel 层和可选 Blackwell FP4 包。

主包 SageAttention2/2++

源码包名为 sageattention,版本在 setup.py:272 声明为 2.2.0。README 明确列出 Ampere/Ada/Hopper 优化、QK INT8、PV FP8、两级累加、torch.compile 非 cudagraph 与分布式推理支持, 见 README.md:23

Blackwell FP4 分支

sageattention3_blackwell 构建独立包 sageattn3,会编译 fp4attn_cudafp4quant_cuda。 Python 入口先预处理 Q/K/V,再打包 FP4 与 FP8 scale,最后调用 C++ CUDA kernel;构建脚本要求 CUDA 12.8+, 见 sageattention3_blackwell/setup.py:60

执行流水线

这里以主包 sageattention.sageattn() 和 SM89/SM90 类 FP8 后端为代表,串起最常见的推理路径。

输入校验与布局解释

API 支持 HNDNHD 两种布局,Q/K/V 必须在同一 CUDA 设备且 dtype 一致; 主实现只接受 FP16/BF16 输入,head_dim 会被 padding 到 64 或 128,大于 128 直接拒绝。 见 core.py:724

GPU 架构分派

get_cuda_arch_versions() 读取所有 CUDA 设备架构,sageattn() 用当前 Q 所在设备索引选出分支。 SM89 默认进入 sageattn_qk_int8_pv_fp8_cuda(..., pv_accum_dtype="fp32+fp16"), SM90 进入专用 sageattn_qk_int8_pv_fp8_cuda_sm90(..., pv_accum_dtype="fp32+fp32")

smooth_k 与 LSE 校正

开启 smooth_k 时,代码先沿序列维度计算 K 的均值,GQA 场景会把 KV head 的均值广播到 Q head。 如果用户请求 return_lse,还会计算 q @ mean(K) 形式的校正项,返回时加回 core.py:772

Q/K INT8 量化

CUDA 后端的 per_warp_int8 将 Q 的 scale 做到 warp 粒度,K 仍以 block 粒度为主; per_thread_int8 是 Triton kernel,Q scale 扩展到每个 thread lane 分组。SM120/121 自动分支强制使用 per_warp

V 的 FP16/FP8 准备

SM80/SM86 走 FP16 PV;SM89/SM90/SM120/SM121 走 FP8 V。FP8 V 量化时会做 transpose、padding 到 64 对齐、固定顺序 permute、per-channel scale。 当 SM89 使用 pv_accum_dtype="fp32+fp16" 时,scale_max 从 448 降到 2.25 以适配 FP16 二级累加。

attention kernel 前向

kernel 输入包含 INT8 Q/K、FP16 或 FP8 V、scale、layout、causal 标志、量化粒度和 softmax scale。 SM90 C++ kernel 内部将 sm_scale 乘以 log2(e),维护 online softmax 的 m/d 状态,并通过 TMA tensor map 加载 Q/K/V。

后端矩阵

源码里的“自动选择”其实是非常明确的架构表;手动调用显式函数时,也应按这张表理解精度与限制。

GPU 架构 sageattn() 分支 Q/K 路径 V/PV 路径 源码确认点
SM80
A100/A800 类
sageattn_qk_int8_pv_fp16_cuda INT8,默认 per_thread FP16 V,默认 FP32 PV 累加 core.py:144;CUDA extension 暴露 FP16/FP32 累加变体。
SM86
RTX 3090 类
sageattn_qk_int8_pv_fp16_triton INT8 per-block,Triton FP16 PV,Triton attention core.py:146;非 causal 路径支持 attn_mask
SM89
Ada / RTX 4090 / L20
sageattn_qk_int8_pv_fp8_cuda INT8,默认 per_thread FP8 V;默认 fp32+fp16 两级累加 core.py:148;SM89 pybind 暴露 FP8/fuse scale/inst buffer 变体。
SM90
H100/H20 类
sageattn_qk_int8_pv_fp8_cuda_sm90 INT8,默认 per_thread FP8 V;只实现 fp32+fp32 推荐路径 core.py:985fp32 直接抛 NotImplementedError
SM120/121
RTX 5090 / Blackwell 消费级
sageattn_qk_int8_pv_fp8_cuda INT8,强制 per_warp FP8 V;fp32+fp16 core.py:152 注释说明 Triton kernel 当前不可用。
SM100
B200/GB200
sageattn() 未分派 主包 setup 可编译 SM100 标志,但 Python 自动调度不覆盖 主包无 SM100 自动路径 setup.py:156 有 SM100 编译标志;core.py:157 对未知 arch 抛错。

API/调度:自动入口很薄,真实选择写死在 core.py

sageattn() 的签名接收 **kwargs,但当前自动入口没有把这些 kwargs 转发给底层显式函数。 这意味着如果需要手动指定 qk_quant_granpv_accum_dtypesmooth_v 等参数,应该直接调用显式后端函数。

sageattn_varlen() 是单独的变长路径:它使用 Triton per-block INT8 量化和 FP16 PV,不进入 SM89/SM90 FP8 CUDA 后端。 README 只说支持 varying sequence lengths,但源码里 smooth_k 对 varlen 的均值是按全部 token 维度计算,注释也说明“按每个序列单独计算需要专用 kernel”, 见 core.py:432

量化模块:scale 的形状决定 kernel 的读取方式

per_block_int8 为 Q/K 按 block 生成 scale;per_warp_int8 将 Q scale 拆到 warp 粒度; per_thread_int8 的 Triton kernel 对 Q 生成 *8 的 lane 分组 scale,对 K 生成 *4 的分组 scale。 SM90 C++ kernel 在 scale shape 检查处明确区分 per_warp 与 per_thread。

V 的 FP8 量化不是简单 cast:源码先调用 fused kernel 做 transpose/pad/permute,再用 E4M3 FP8 输出和 FP32 scale。 smooth_v 只在部分 FP32 累加路径有意义;两级累加路径会发出 warning 并忽略。

CUDA kernel:SM80、SM89、SM90 的差异主要在 PV dtype 与硬件路径

SM80 extension 暴露 qk_int8_sv_f16_accum_f16_attnaccum_f32inst_buffuse_v_mean 变体。 SM89 extension 暴露 FP8 V scale 融合与两级累加变体。 SM90 extension 只暴露 FP8 inst buffer 路径,Python 层也拒绝普通 fp32 累加模式。

SM90 的 C++ kernel 使用 CUtensorMap 创建 Q/K/V 的 TMA 加载映射,并以 CTA_Q=64CTA_K=128NUM_THREADS=128 发射 kernel。

SageAttention3/Blackwell:源码是 FP4 包,不是主包增强版

sageattn3_blackwell README 要求 Python、Torch、CUDA 较新的环境,并要求从源码编译。 Python API 里会对 K 做均值平滑、对 Q 做 per-block 或全局均值平滑,并计算 delta_s 作为校正项; 随后 Q/K/V 都被 packed 成 FP4,scale 用 FP8 E4M3 存储。

一个关键源码事实:C++ 入口检查当前 GPU 必须是 SM120 或 SM121,并报错文案写着 “Blackwell GPUs or newer”。 这和主包 sageattn() 的 SM120/121 INT8+FP8 路径是两条独立路线,见 blackwell/api.cu:219

FlashAttention 对照:作为选型背景,而非本仓库实现

关于 FA1/2/3/4 的内容适合作为工程选型背景:FlashAttention 的核心是减少 HBM IO、保持 attention 数值形式更接近原计算,并提供训练/反向和 KV cache 生态能力。 但本页的源码事实只覆盖 SageAttention 仓库;FA4/B200 的具体能力不应从这个源码树推断。

因此对比时可以简化为:训练、复杂 mask、paged KV、decode 场景优先考虑 FlashAttention 或框架内核;视觉/视频生成、长序列 prefill、可接受轻微量化误差的推理场景,SageAttention 更值得测试。

功能模块拆解

按源码目录理解模块,比按论文名理解更稳:主包其实是 Python dispatch + Triton kernels + C++ CUDA extensions 的组合。

sageattention/core.py

总控层。负责自动调度、显式后端 API、输入校验、head_dim padding、smooth_k/LSE 校正、V 量化调用和 CUDA extension 调用。

sageattention/quant.py

CUDA fused 量化封装层。包括 Q/K per-block INT8、Q per-warp INT8、V sub_mean、V per-channel FP8。

sageattention/triton/*

Triton 量化和 attention 前向。SM86 自动路径、sageattn_varlen、per-thread INT8 量化均依赖这里。

csrc/fused/*

fused CUDA 量化算子绑定。Python 中的 _fused.quant_per_block_int8_fuse_sub_mean_cuda 等接口来自这里。

csrc/qattn/*

架构专用 attention kernel。SM80 处理 INT8 QK + FP16 V,SM89/SM90 处理 INT8 QK + FP8 V。

sageattention3_blackwell/*

独立 FP4 包。包含 FP4 quantization、Blackwell kernel、CUTLASS/CuTe 相关适配和单独的 Python API。

精度与量化策略

SageAttention 的关键风险不在“是否用了 online softmax”,而在 QK 与 PV 的低比特表示如何控制误差。

机制 作用位置 源码行为 工程含义
smooth_k K 量化前 沿序列维度计算 K mean;CUDA block/warp 量化可 fused subtract mean。 降低 K outlier 对 INT8 scale 的污染;开启后 return_lse 需要加回校正项。
per_block Q/K INT8 Triton FP16 路径主要使用;scale 数量最少。 开销低但粒度粗,适合作为简单后端或变长路径。
per_warp Q INT8 CUDA fused Q quant;SM120/121 自动路径强制使用。 比 per-block 精细,兼顾 scale 开销与精度。
per_thread Q/K INT8 Triton kernel 为每个 lane 分组生成 scale;SM80/89/90 默认使用。 精度更细,但 scale 更多,kernel 需要匹配 scale 布局。
per_channel_fp8 V/PV V 转置、padding、permute 后输出 float8_e4m3fn 与 FP32 scale。 把 PV 路径放到 FP8 tensor core/WGMMA;量化和布局转换开销不能忽略。
两级累加 PV accumulation fp16+fp32fp32+fp16fp32+fp32 等 inst buffer 变体。 为长序列与 FP8/FP16 累加缓解精度损失,是 SageAttention2++ 重要工程点。

注意:README 的 TOPS 图表说明 attention kernel 速度不包含量化与 smoothing 开销。 因此端到端替换时,短序列、decode、小 batch 或量化开销占比高的场景,未必等同于 kernel 图上的收益。

核心技术对比

把精度、性能、内存、功能支持和 GPU 覆盖合并成一组复习表。SageAttention 列以当前源码为准,FA4 列按论文/生态背景理解。

计算精度路线

维度FlashAttentionSageAttention
QKT FP16/BF16,保持标准 attention 数学形式。 INT8 量化,使用 scale 还原分数,存在近似误差。
Softmax online softmax,等价于标准 softmax。 attention kernel 内也维护 online softmax,但输入分数来自量化 QK。
PV FP16/BF16;FA3/FA4 背景中还涉及 FP8 路线。 SM80/86 为 FP16 V;SM89/90/120/121 为 FP8 V;SA3 独立包为 FP4 packed Q/K/V。
误差来源 主要来自浮点舍入,与标准实现目标等价。 Q/K/V 低比特量化、scale 粒度、累加器精度和 smoothing 策略。

内存与额外缓冲

维度FlashAttentionSageAttention
前向显存 O(N),不保存 N×N S/P。 同样不应物化 N×N S/P,但会新增 INT8/FP8 tensor 与 scale 缓冲。
训练显存 保存 LSE 等状态,反向按块重算 P。 当前主包源码没有 backward API,不能作为训练 attention 直接替换。
端到端开销 主要看 kernel 调度、mask/KV cache 支持和架构适配。 需要把量化、smooth_k、V 转置/padding/permute 的额外开销算进去。
功能 FlashAttention / FA2/FA3 FA4 背景 当前 SageAttention 源码事实
Backward / 训练 支持 论文/生态目标支持 主包未暴露 backward
Causal 支持 支持背景 支持,但 Triton 路径中 causal 与 attn_mask 互斥。
Variable Length 支持 支持背景 sageattn_varlen,走 Triton INT8+FP16,不是 FP8 CUDA 路径。
GQA/MQA 支持 支持背景 支持约束num_qo_heads 必须可整除 num_kv_heads
Dropout 训练内核常见支持 看具体版本 主包 API 无 dropout 参数
Sliding Window / Paged KV 推理生态成熟 重要方向 主包未见主线 API
Attention Mask 支持 score/mask mod 背景 仅 Triton FP16 显式路径有通用 attn_mask
torch.compile 生态支持 custom op 背景 README 声明 non-cudagraphs 支持,源码里 custom_op 注册 fake impl。

选型与踩坑

训练/推理、模型类型、硬件平台判断;SageAttention 的能力边界以本地源码为准,FlashAttention/FA4 作为外部生态背景。

快速判断:训练、复杂 mask、paged KV、decode 为主、MLA/非标准 attention,优先 FlashAttention 或框架专用内核; 标准 MHA/GQA、head_dim ≤ 128、视觉/视频生成或长 prompt prefill、可接受低比特量化误差,SageAttention 值得压测。

更适合 SageAttention 的场景

视觉/视频生成、DiT/ViT 类大矩阵 attention、长序列 prefill、对轻微量化误差可容忍的推理。README 的示例主要围绕 CogVideoX、Mochi、HunyuanVideo、LTX 等视频/图像模型,并提示不是所有模型都适合直接 monkey patch F.scaled_dot_product_attention

更适合 FlashAttention/框架内核的场景

训练或需要 backward、复杂 KV cache、LLM decode、sliding window、paged KV、dropout、精度极敏感任务。当前主包源码没有暴露 backward kernel,也没有主线 paged KV cache API;这些场景不要仅凭 forward 替换做结论。

训练 vs 推理

场景推荐原因
训练 FlashAttention / 框架专用 kernel 需要 backward、dropout、复杂 mask 和稳定数值;当前 SageAttention 主包源码只暴露前向推理路径。
推理且精度敏感 先用 FlashAttention 做基线,再测试 SageAttention SageAttention 有 INT8/FP8/FP4 量化误差,必须用任务指标确认。
推理且视觉/视频生成 SageAttention 值得优先测试 原仓库示例集中在 Diffusion/视频生成;这类任务通常标准 attention 多、计算密集、对微小 attention 误差较宽容。

按模型类型判断

模型/阶段SageAttention 判断说明
Diffusion / DiT / 视频生成 适合 标准 MHA/GQA、head_dim 常为 64/128、推理为主,仓库 README 和 examples 也主要覆盖这类模型。
Vision Transformer 适合压测 分类/检测/分割任务通常可用任务指标评估轻微量化误差,前向推理收益可能较好。
LLM Prefill 可测试 Prefill 是大矩阵 attention,compute-bound 特征更明显;但自回归模型的输出质量需要仔细评估。
LLM Decode 收益有限 Q 只有少量 token,瓶颈常在 KV 读取和调度;FlashAttention/推理框架的 split-KV、paged KV 通常更成熟。
MLA / DeepSeek 类 absorbed attention 不建议直接套 结构非标准,等效 head_dim/数据流可能不满足 SageAttention 主包限制;应使用 vLLM/SGLang 等专用 MLA kernel。
原生 FP8 模型 看目标 FlashAttention3-FP8 避免反复转换;SageAttention 需要回到 FP16/BF16 输入再量化成 INT8/FP8,可能有额外开销。

按硬件平台判断

硬件复习结论准确性限定
A100 / SM80 训练用 FA2/框架内核;推理可测 SageAttention CUDA INT8+FP16。 SageAttention 自动分支为 SM80 FP16 PV,默认 FP32 累加。
RTX 3090 / SM86 SageAttention 自动走 Triton INT8+FP16。 没有 FP8 PV 自动路径,功能相对简单。
RTX 4090 / L20 / SM89 SageAttention2++ 的典型强项:INT8 QK + FP8 PV + fp32+fp16 需要 CUDA 12.4+ 构建 FP8 支持。
H100/H20 / SM90 Diffusion 推理可比较 SageAttention 与 FA3/FA3-FP8;训练和复杂 LLM 能力优先 FA3/FA4/框架。 SageAttention SM90 只建议 fp32+fp32 路径;普通 fp32 未实现。
B200/GB200 / SM100 倾向 FA4;本地 SageAttention 主包没有 SM100 自动分派。 主包 setup 有 SM100 编译标志,但公开 sageattn() 进入 unknown arch 抛错;SA3 本地 C++ 入口也只放行 SM120/121。
RTX 5090 / SM120/121 SageAttention 主包自动走 INT8+FP8 per_warp;SA3 FP4 包也面向 SM120/121。 需要 CUDA 12.8+;具体端到端收益仍要用模型实测。

压缩版决策树

是否训练?
  是  -> FlashAttention / 框架专用 kernel
  否  -> 继续

是否 B200/SM100?
  是  -> FA4 或框架内核;当前本地 SageAttention 主包无 sageattn() 自动路径
  否  -> 继续

是否标准 MHA/GQA 且 head_dim <= 128?
  否  -> FlashAttention / 模型专用内核
  是  -> 继续

是否视觉/视频生成、ViT、长 prompt prefill?
  是  -> SageAttention 值得优先压测
  否  -> 继续

是否 decode、paged KV、sliding window、block-sparse、自定义 score/mask?
  是  -> FlashAttention / FA4 / 推理框架内核
  否  -> 用 FlashAttention 做基线,再比较 SageAttention 端到端质量和吞吐

参数转发:sageattn() 接收 **kwargs 但不传下去;需要改 pv_accum_dtypesmooth_v 时,请调用显式后端函数。

layout:HNDNHD 只改变 shape 解释,不会自动重排输入;传错 layout 会让 head/seq 维度完全错位。

head_dim:源码只支持 padding 到 64/128;原始 head_dim 大于 128 会直接抛错。

SM100/B200:主包 setup 有 SM100 编译标志,但 sageattn() 没有 SM100 分派;Blackwell FP4 包的 C++ 入口又检查 SM120/121。 这是一处“能编译部分标志”和“公开入口可运行”并不等价的边界。

变长 smooth_k:sageattn_varlen() 对 K mean 的注释说明它按全部 batch token 计算,不是每个序列各自计算。

causal mask:Triton FP16 路径里 causal 与 attn_mask 互斥;CUDA 显式路径没有通用 attn_mask 参数。

源码走读顺序

按这个顺序读,能最快把“API 怎么进来、数据怎么变、kernel 怎么发射”串起来。

1. 公共导出

sageattention/__init__.py:1 看包对外暴露哪些 API。

2. 自动分派

core.py:79sageattn() 的参数、文档和 SM 分支。

3. Triton FP16 路径

core.py:160 看 SM86/Triton 的量化、mask 和 return_lse。

4. 变长路径

core.py:334sageattn_varlen() 的 cu_seqlens 与 per-block 量化。

5. SM80 CUDA FP16 PV

core.py:451 看 FP16 PV 的三种累加选择。

6. SM89 CUDA FP8 PV

core.py:636 看 FP8 V 量化与两级累加。

7. SM90 专用路径

core.py:829 看 Hopper WGMMA/TMA 后端入口。

8. CUDA 量化封装

quant.py:22quant.py:224 串起 INT8/FP8 准备。

9. fused pybind

csrc/fused/pybind.cpp:21 看 Python 调用的 fused CUDA 符号。

10. SM90 C++ kernel

qk_int_sv_f8_cuda_sm90.cu:126 看 TMA、scale index、online softmax 状态。

11. 构建与 CUDA 限制

setup.py:129 看 CUDA 版本和架构编译条件。

12. SageAttention3 FP4

sageattn3/api.py:75 看 Blackwell FP4 预处理和量化调用。

准确性校准

这一节专门把容易随时间变化、或和当前源码边界不同的表述校准成复习时更安全的说法。

主题 复习时的安全表述 依据/边界
SA2/SA3 与 SM100/B200 当前本地 SageAttention 主包 sageattn() 没有 SM100 自动分派;SA3 Blackwell 包当前 C++ 入口检查 SM120/121。 主包 core.py:143 分派缺少 SM100;SA3 api.cu:219 检查 SM120/121。
“SageAttention 无 backward” 对当前主包源码可以这样说:公开 API 和 pybind 主要是 forward attention,没有训练 backward 替换路径。 不要把它外推成论文永远不研究训练;SA3 论文标题包含 8-bit training exploration,但本地公开入口仍是推理前向。
FA4 版本/PR/性能数字 作为论文和生态背景复习,不作为本 SageAttention 仓库事实。实际项目中要重新查上游 release、README、benchmark。 这类信息高度时间敏感,本文只保留技术方向:Blackwell、TMEM/UMMA、CuTe DSL、FlexAttention、block-sparse。
H100 上 SageAttention vs FA3-FP8 可作为 Diffusion 推理候选,但必须用目标模型端到端质量和吞吐实测。 SageAttention README 的示例和 TOPS 图强调 kernel 速度,且 TOPS 不含量化/smoothing 开销。
LLM Prefill/Decode Prefill 更可能从低比特矩阵乘受益;Decode 常受 KV 读、调度和 cache 机制限制,FlashAttention/推理框架更常见。 这是工程推断,不是本 SageAttention 仓库提供的 LLM decode 专用 API。
直接替换 SDPA 可以快速试验,但不应当作为所有模型的最终集成方式。 README 明确提示并非所有模型都适合 F.scaled_dot_product_attention = sageattn,图像/视频模型建议改 Attention class。

未来趋势

最后的趋势可以总结为:IO-aware、低比特量化、可编程稀疏和推理框架动态调度会继续融合。

FlashAttention 方向

FA4 代表的是架构深度绑定的 exact attention 路线:利用 Blackwell 的异步 MMA、TMEM、CuTe DSL 和 FlexAttention, 在保持数值形式更稳定的同时扩展 block-sparse、score/mask modifier 和训练能力。

SageAttention 方向

SageAttention2/2++ 证明了 INT8 QK + FP8 PV 在推理中可以获得很强吞吐;SageAttention3 进一步探索 FP4 和训练相关问题。 但当前本地源码中,SA3 是独立 Blackwell FP4 包,主包和 SA3 的硬件入口边界要分开记。

推理引擎融合

对 vLLM、SGLang 这类引擎,未来更现实的路线不是单一 attention kernel 通吃,而是按 GPU、模型结构、prefill/decode 阶段、KV cache 和精度目标动态选择后端。

一句话复习:
B200 / 训练 / 复杂 mask      -> FlashAttention / FA4 / 框架内核
视觉视频推理 / 标准 MHA/GQA  -> SageAttention 值得优先压测
LLM prefill                 -> 两者都测,关注质量和端到端吞吐
LLM decode / paged KV       -> FlashAttention 或推理框架专用内核
SM100 上当前本地 SageAttention -> 不要假设 sageattn() 可直接跑

参考文献与复习入口

参考文献列表,并补充源码入口。外部论文链接用于理解技术路线,源码链接用于确认当前实现。

学习检查清单

读完源码后,可以用这些问题自测是否真正抓住实现骨架。