FlashAttention 不物化 N×N
FlashAttention 的核心是 tiling + online softmax:在 SRAM 中分块计算 QK、softmax、PV,只把 Q/K/V/O/LSE 等 O(N) 状态留在 HBM, 避免存储 S/P 两个 N×N 矩阵。
Full learning note with source calibration
先讲标准 Attention 的瓶颈, 再串起 FlashAttention v1/v2/v3/v4 的 IO-aware 路线,最后用 SageAttention 2.2.0 源码校准 INT8/FP8/FP4 量化路线、GPU 后端、功能边界和选型决策。
两条路线解决同一个 Attention 瓶颈:FlashAttention 优先减少 HBM IO 并保持数学等价;SageAttention 优先降低 QK/PV 计算精度以换取更高 tensor core 吞吐。
FlashAttention 的核心是 tiling + online softmax:在 SRAM 中分块计算 QK、softmax、PV,只把 Q/K/V/O/LSE 等 O(N) 状态留在 HBM, 避免存储 S/P 两个 N×N 矩阵。
主包的所有 SageAttention2/2++ 前向路径都会把 Q/K 量化为 INT8,并携带 FP32 scale 进入 attention kernel。 CUDA 量化入口在 quant.py:22 与 quant.py:105。
sageattn() 读取当前设备 compute capability,SM80/86/89/90/120/121 分别进入不同后端;
分派逻辑集中在 core.py:143。
SM80/SM86 使用 FP16 PV,SM89/SM90/SM120/SM121 使用 FP8 V 路径; FP8 V 会先转置、padding、permute,再做 per-channel 量化,入口在 quant.py:224。
Blackwell FP4 实现在 sageattention3_blackwell,入口是 sageattn3_blackwell(),
不是主包 sageattention.sageattn() 自动调度的一部分。
准确性边界:FlashAttention/FA4 部分来自论文与生态背景,用于复习技术路线;SageAttention 具体行为以本地源码 commit
3647690 为准。凡是“开发中、PR 中、版本号、硬件支持状态”这类会快速变化的信息,生产选型前都需要重新核对上游。
标准 Attention 的 FLOPs 是 O(N²d),但工程上更痛的是 N×N 中间矩阵和 HBM 往返。
S = Q × K^T # N × N scores
P = softmax(S) # N × N probabilities
O = P × V # N × d output
如果直接实现,S 和 P 都会成为 N×N 矩阵。序列越长,显存和带宽压力按平方增长;N=4096 时,单 head 的 FP16 attention matrix 已经是数十 MB 量级。
| 路线 | 核心动作 | 代价 |
|---|---|---|
| FlashAttention | 减少 HBM IO,不物化 N×N,online softmax 精确重缩放。 | kernel 编排复杂,对 GPU 架构高度适配。 |
| SageAttention | Q/K INT8、V FP8/FP16/FP4,利用低比特 tensor core 吞吐。 | 引入量化误差,需要 smoothing、细粒度 scale 和两级累加控制精度。 |
FlashAttention 的关键不是减少理论 FLOPs,而是让 Q/K/V 小块在 SRAM 中完成计算,避免反复读写 N×N 的 score/probability。
Q、K、V 被切成适合 SRAM 的 tile。每个 Q tile 依次扫描 K/V tile,在片上完成局部 score、softmax 统计量更新和输出累加。 HBM 只保留输入、输出和每行 LSE 这样的 O(N) 状态。
softmax 的全局分母通过增量维护行最大值 m 和指数和 l 完成。每处理一个 K block,如果新的最大值变大,
旧输出乘以 exp(m_old - m_new) 精确重缩放。
前向不保存 P,而保存 O 与每行 LSE;反向按块重算 P_ij = exp(Q_i K_j^T / sqrt(d) - L_i)。
这牺牲一部分计算,换显存和 HBM 访问下降。
m_new = max(m_old, rowmax(S_ij))
scale = exp(m_old - m_new)
l_new = l_old * scale + rowsum(exp(S_ij - m_new))
O_new = O_old * scale + exp(S_ij - m_new) @ V_j
O_final = O_new / l_new
从 FA1 到 FA4,主线是把非矩阵乘开销、同步、数据搬运和硬件异步能力一点点压进更深的流水线。
用 tiling + online softmax 避免保存 S/P,证明在合理 SRAM 大小下可以达到更优 HBM IO 复杂度。 这是 FlashAttention 系列的“数学骨架”。
交换循环顺序,提高长序列下的 SM 利用率;让 warp 持有完整行以减少 warp 间通信;延迟归一化,减少非矩阵乘 FLOPs。
使用异步 WGMMA 做矩阵乘、TMA 做 HBM/SRAM 搬运,生产者/消费者 warp 特化。Ping-pong 调度让一个 warpgroup 做 GEMM 时,另一个处理 softmax。
笔记中的 FA4 信息应视为论文/生态背景:面向 Blackwell 的异步 MMA、TMEM、软件 exp、条件重缩放、CuTe DSL 和 FlexAttention 集成。 这些不是本 SageAttention 仓库源码实现。
这部分来自论文梳理,用来理解为什么 B200/SM100 上 FlashAttention 路线仍然很强。
将 FA4 拆成 load、MMA、softmax、correction、epilogue 多类 warp。核心意图是把完全异步的 tensor core 工作、softmax 和数据搬运重叠起来, 防止 Blackwell 上 SFU/共享内存等非对称资源成为瓶颈。
当硬件指数单元跟不上 tensor core 吞吐时,FA4 用 FMA 上的多项式近似实现部分 exp2,
并只在数值稳定性需要时重缩放中间输出,减少 online softmax 的校正次数。
Blackwell 的 Tensor Memory 可作为 tensor core 直连累加空间。2-CTA MMA 可以扩大 tile 并减少部分反向路径中的冗余数据流和原子归约压力。
强调 FA4 从 CUDA C++ 模板转向 Python 侧 CuTe DSL,降低 kernel 变体开发和编译迭代成本。这一点属于 FA4 生态实现,不是 SageAttention 当前源码。
score_mod 与 mask_mod 可以把 ALiBi、softcapping、sliding window、document mask 等模式在编译时注入 kernel。 对稀疏 attention,关键是跳过 Empty block、简化 Full block、只对 Partial block 应用 mask。
FA4 版本号、PR、GPU 支持状态和具体 TFLOPS 数字都属于上游快速变化内容。HTML 保留这些作为复习线索, 但不把它们当作本 SageAttention 仓库可证明的源码事实。
从 Python API 到 CUDA/Triton kernel,关键分层可以拆成 API 层、调度层、量化层、attention kernel 层和可选 Blackwell FP4 包。
sageattention.__init__ 暴露 sageattn、sageattn_varlen 以及几个显式后端函数。
见 __init__.py:1。
core.py 校验 dtype/device/layout/head_dim,按 SM 架构选择 FP16 或 FP8 PV 后端,并处理 smooth_k、LSE 校正。
源码包名为 sageattention,版本在 setup.py:272
声明为 2.2.0。README 明确列出 Ampere/Ada/Hopper 优化、QK INT8、PV FP8、两级累加、torch.compile 非 cudagraph 与分布式推理支持,
见 README.md:23。
sageattention3_blackwell 构建独立包 sageattn3,会编译 fp4attn_cuda 与 fp4quant_cuda。
Python 入口先预处理 Q/K/V,再打包 FP4 与 FP8 scale,最后调用 C++ CUDA kernel;构建脚本要求 CUDA 12.8+,
见 sageattention3_blackwell/setup.py:60。
这里以主包 sageattention.sageattn() 和 SM89/SM90 类 FP8 后端为代表,串起最常见的推理路径。
API 支持 HND 和 NHD 两种布局,Q/K/V 必须在同一 CUDA 设备且 dtype 一致;
主实现只接受 FP16/BF16 输入,head_dim 会被 padding 到 64 或 128,大于 128 直接拒绝。
见 core.py:724。
get_cuda_arch_versions() 读取所有 CUDA 设备架构,sageattn() 用当前 Q 所在设备索引选出分支。
SM89 默认进入 sageattn_qk_int8_pv_fp8_cuda(..., pv_accum_dtype="fp32+fp16"),
SM90 进入专用 sageattn_qk_int8_pv_fp8_cuda_sm90(..., pv_accum_dtype="fp32+fp32")。
开启 smooth_k 时,代码先沿序列维度计算 K 的均值,GQA 场景会把 KV head 的均值广播到 Q head。
如果用户请求 return_lse,还会计算 q @ mean(K) 形式的校正项,返回时加回
core.py:772。
CUDA 后端的 per_warp_int8 将 Q 的 scale 做到 warp 粒度,K 仍以 block 粒度为主;
per_thread_int8 是 Triton kernel,Q scale 扩展到每个 thread lane 分组。SM120/121 自动分支强制使用 per_warp。
SM80/SM86 走 FP16 PV;SM89/SM90/SM120/SM121 走 FP8 V。FP8 V 量化时会做 transpose、padding 到 64 对齐、固定顺序 permute、per-channel scale。
当 SM89 使用 pv_accum_dtype="fp32+fp16" 时,scale_max 从 448 降到 2.25 以适配 FP16 二级累加。
kernel 输入包含 INT8 Q/K、FP16 或 FP8 V、scale、layout、causal 标志、量化粒度和 softmax scale。
SM90 C++ kernel 内部将 sm_scale 乘以 log2(e),维护 online softmax 的 m/d 状态,并通过 TMA tensor map 加载 Q/K/V。
源码里的“自动选择”其实是非常明确的架构表;手动调用显式函数时,也应按这张表理解精度与限制。
| GPU 架构 | sageattn() 分支 |
Q/K 路径 | V/PV 路径 | 源码确认点 |
|---|---|---|---|---|
| SM80 A100/A800 类 |
sageattn_qk_int8_pv_fp16_cuda |
INT8,默认 per_thread |
FP16 V,默认 FP32 PV 累加 | core.py:144;CUDA extension 暴露 FP16/FP32 累加变体。 |
| SM86 RTX 3090 类 |
sageattn_qk_int8_pv_fp16_triton |
INT8 per-block,Triton | FP16 PV,Triton attention | core.py:146;非 causal 路径支持 attn_mask。 |
| SM89 Ada / RTX 4090 / L20 |
sageattn_qk_int8_pv_fp8_cuda |
INT8,默认 per_thread |
FP8 V;默认 fp32+fp16 两级累加 |
core.py:148;SM89 pybind 暴露 FP8/fuse scale/inst buffer 变体。 |
| SM90 H100/H20 类 |
sageattn_qk_int8_pv_fp8_cuda_sm90 |
INT8,默认 per_thread |
FP8 V;只实现 fp32+fp32 推荐路径 |
core.py:985 对 fp32 直接抛 NotImplementedError。 |
| SM120/121 RTX 5090 / Blackwell 消费级 |
sageattn_qk_int8_pv_fp8_cuda |
INT8,强制 per_warp |
FP8 V;fp32+fp16 |
core.py:152 注释说明 Triton kernel 当前不可用。 |
| SM100 B200/GB200 |
主 sageattn() 未分派 |
主包 setup 可编译 SM100 标志,但 Python 自动调度不覆盖 | 主包无 SM100 自动路径 | setup.py:156 有 SM100 编译标志;core.py:157 对未知 arch 抛错。 |
core.py
sageattn() 的签名接收 **kwargs,但当前自动入口没有把这些 kwargs 转发给底层显式函数。
这意味着如果需要手动指定 qk_quant_gran、pv_accum_dtype、smooth_v 等参数,应该直接调用显式后端函数。
sageattn_varlen() 是单独的变长路径:它使用 Triton per-block INT8 量化和 FP16 PV,不进入 SM89/SM90 FP8 CUDA 后端。
README 只说支持 varying sequence lengths,但源码里 smooth_k 对 varlen 的均值是按全部 token 维度计算,注释也说明“按每个序列单独计算需要专用 kernel”,
见 core.py:432。
per_block_int8 为 Q/K 按 block 生成 scale;per_warp_int8 将 Q scale 拆到 warp 粒度;
per_thread_int8 的 Triton kernel 对 Q 生成 *8 的 lane 分组 scale,对 K 生成 *4 的分组 scale。
SM90 C++ kernel 在 scale shape 检查处明确区分 per_warp 与 per_thread。
V 的 FP8 量化不是简单 cast:源码先调用 fused kernel 做 transpose/pad/permute,再用 E4M3 FP8 输出和 FP32 scale。
smooth_v 只在部分 FP32 累加路径有意义;两级累加路径会发出 warning 并忽略。
SM80 extension 暴露 qk_int8_sv_f16_accum_f16_attn、accum_f32、inst_buf 与 fuse_v_mean 变体。
SM89 extension 暴露 FP8 V scale 融合与两级累加变体。
SM90 extension 只暴露 FP8 inst buffer 路径,Python 层也拒绝普通 fp32 累加模式。
SM90 的 C++ kernel 使用 CUtensorMap 创建 Q/K/V 的 TMA 加载映射,并以 CTA_Q=64、CTA_K=128、NUM_THREADS=128 发射 kernel。
sageattn3_blackwell README 要求 Python、Torch、CUDA 较新的环境,并要求从源码编译。
Python API 里会对 K 做均值平滑、对 Q 做 per-block 或全局均值平滑,并计算 delta_s 作为校正项;
随后 Q/K/V 都被 packed 成 FP4,scale 用 FP8 E4M3 存储。
一个关键源码事实:C++ 入口检查当前 GPU 必须是 SM120 或 SM121,并报错文案写着 “Blackwell GPUs or newer”。
这和主包 sageattn() 的 SM120/121 INT8+FP8 路径是两条独立路线,见
blackwell/api.cu:219。
关于 FA1/2/3/4 的内容适合作为工程选型背景:FlashAttention 的核心是减少 HBM IO、保持 attention 数值形式更接近原计算,并提供训练/反向和 KV cache 生态能力。 但本页的源码事实只覆盖 SageAttention 仓库;FA4/B200 的具体能力不应从这个源码树推断。
因此对比时可以简化为:训练、复杂 mask、paged KV、decode 场景优先考虑 FlashAttention 或框架内核;视觉/视频生成、长序列 prefill、可接受轻微量化误差的推理场景,SageAttention 更值得测试。
按源码目录理解模块,比按论文名理解更稳:主包其实是 Python dispatch + Triton kernels + C++ CUDA extensions 的组合。
sageattention/core.py总控层。负责自动调度、显式后端 API、输入校验、head_dim padding、smooth_k/LSE 校正、V 量化调用和 CUDA extension 调用。
sageattention/quant.pyCUDA fused 量化封装层。包括 Q/K per-block INT8、Q per-warp INT8、V sub_mean、V per-channel FP8。
sageattention/triton/*
Triton 量化和 attention 前向。SM86 自动路径、sageattn_varlen、per-thread INT8 量化均依赖这里。
csrc/fused/*
fused CUDA 量化算子绑定。Python 中的 _fused.quant_per_block_int8_fuse_sub_mean_cuda 等接口来自这里。
csrc/qattn/*架构专用 attention kernel。SM80 处理 INT8 QK + FP16 V,SM89/SM90 处理 INT8 QK + FP8 V。
sageattention3_blackwell/*独立 FP4 包。包含 FP4 quantization、Blackwell kernel、CUTLASS/CuTe 相关适配和单独的 Python API。
SageAttention 的关键风险不在“是否用了 online softmax”,而在 QK 与 PV 的低比特表示如何控制误差。
| 机制 | 作用位置 | 源码行为 | 工程含义 |
|---|---|---|---|
smooth_k |
K 量化前 | 沿序列维度计算 K mean;CUDA block/warp 量化可 fused subtract mean。 | 降低 K outlier 对 INT8 scale 的污染;开启后 return_lse 需要加回校正项。 |
per_block |
Q/K INT8 | Triton FP16 路径主要使用;scale 数量最少。 | 开销低但粒度粗,适合作为简单后端或变长路径。 |
per_warp |
Q INT8 | CUDA fused Q quant;SM120/121 自动路径强制使用。 | 比 per-block 精细,兼顾 scale 开销与精度。 |
per_thread |
Q/K INT8 | Triton kernel 为每个 lane 分组生成 scale;SM80/89/90 默认使用。 | 精度更细,但 scale 更多,kernel 需要匹配 scale 布局。 |
per_channel_fp8 |
V/PV | V 转置、padding、permute 后输出 float8_e4m3fn 与 FP32 scale。 |
把 PV 路径放到 FP8 tensor core/WGMMA;量化和布局转换开销不能忽略。 |
| 两级累加 | PV accumulation | fp16+fp32、fp32+fp16、fp32+fp32 等 inst buffer 变体。 |
为长序列与 FP8/FP16 累加缓解精度损失,是 SageAttention2++ 重要工程点。 |
注意:README 的 TOPS 图表说明 attention kernel 速度不包含量化与 smoothing 开销。 因此端到端替换时,短序列、decode、小 batch 或量化开销占比高的场景,未必等同于 kernel 图上的收益。
把精度、性能、内存、功能支持和 GPU 覆盖合并成一组复习表。SageAttention 列以当前源码为准,FA4 列按论文/生态背景理解。
| 维度 | FlashAttention | SageAttention |
|---|---|---|
| QKT | FP16/BF16,保持标准 attention 数学形式。 | INT8 量化,使用 scale 还原分数,存在近似误差。 |
| Softmax | online softmax,等价于标准 softmax。 | attention kernel 内也维护 online softmax,但输入分数来自量化 QK。 |
| PV | FP16/BF16;FA3/FA4 背景中还涉及 FP8 路线。 | SM80/86 为 FP16 V;SM89/90/120/121 为 FP8 V;SA3 独立包为 FP4 packed Q/K/V。 |
| 误差来源 | 主要来自浮点舍入,与标准实现目标等价。 | Q/K/V 低比特量化、scale 粒度、累加器精度和 smoothing 策略。 |
| 维度 | FlashAttention | SageAttention |
|---|---|---|
| 前向显存 | O(N),不保存 N×N S/P。 | 同样不应物化 N×N S/P,但会新增 INT8/FP8 tensor 与 scale 缓冲。 |
| 训练显存 | 保存 LSE 等状态,反向按块重算 P。 | 当前主包源码没有 backward API,不能作为训练 attention 直接替换。 |
| 端到端开销 | 主要看 kernel 调度、mask/KV cache 支持和架构适配。 | 需要把量化、smooth_k、V 转置/padding/permute 的额外开销算进去。 |
| 功能 | FlashAttention / FA2/FA3 | FA4 背景 | 当前 SageAttention 源码事实 |
|---|---|---|---|
| Backward / 训练 | 支持 | 论文/生态目标支持 | 主包未暴露 backward |
| Causal | 支持 | 支持背景 | 支持,但 Triton 路径中 causal 与 attn_mask 互斥。 |
| Variable Length | 支持 | 支持背景 | sageattn_varlen,走 Triton INT8+FP16,不是 FP8 CUDA 路径。 |
| GQA/MQA | 支持 | 支持背景 | 支持约束:num_qo_heads 必须可整除 num_kv_heads。 |
| Dropout | 训练内核常见支持 | 看具体版本 | 主包 API 无 dropout 参数 |
| Sliding Window / Paged KV | 推理生态成熟 | 重要方向 | 主包未见主线 API |
| Attention Mask | 支持 | score/mask mod 背景 | 仅 Triton FP16 显式路径有通用 attn_mask |
| torch.compile | 生态支持 | custom op 背景 | README 声明 non-cudagraphs 支持,源码里 custom_op 注册 fake impl。 |
训练/推理、模型类型、硬件平台判断;SageAttention 的能力边界以本地源码为准,FlashAttention/FA4 作为外部生态背景。
快速判断:训练、复杂 mask、paged KV、decode 为主、MLA/非标准 attention,优先 FlashAttention 或框架专用内核; 标准 MHA/GQA、head_dim ≤ 128、视觉/视频生成或长 prompt prefill、可接受低比特量化误差,SageAttention 值得压测。
视觉/视频生成、DiT/ViT 类大矩阵 attention、长序列 prefill、对轻微量化误差可容忍的推理。README 的示例主要围绕 CogVideoX、Mochi、HunyuanVideo、LTX 等视频/图像模型,并提示不是所有模型都适合直接 monkey patch F.scaled_dot_product_attention。
训练或需要 backward、复杂 KV cache、LLM decode、sliding window、paged KV、dropout、精度极敏感任务。当前主包源码没有暴露 backward kernel,也没有主线 paged KV cache API;这些场景不要仅凭 forward 替换做结论。
| 场景 | 推荐 | 原因 |
|---|---|---|
| 训练 | FlashAttention / 框架专用 kernel | 需要 backward、dropout、复杂 mask 和稳定数值;当前 SageAttention 主包源码只暴露前向推理路径。 |
| 推理且精度敏感 | 先用 FlashAttention 做基线,再测试 SageAttention | SageAttention 有 INT8/FP8/FP4 量化误差,必须用任务指标确认。 |
| 推理且视觉/视频生成 | SageAttention 值得优先测试 | 原仓库示例集中在 Diffusion/视频生成;这类任务通常标准 attention 多、计算密集、对微小 attention 误差较宽容。 |
| 模型/阶段 | SageAttention 判断 | 说明 |
|---|---|---|
| Diffusion / DiT / 视频生成 | 适合 | 标准 MHA/GQA、head_dim 常为 64/128、推理为主,仓库 README 和 examples 也主要覆盖这类模型。 |
| Vision Transformer | 适合压测 | 分类/检测/分割任务通常可用任务指标评估轻微量化误差,前向推理收益可能较好。 |
| LLM Prefill | 可测试 | Prefill 是大矩阵 attention,compute-bound 特征更明显;但自回归模型的输出质量需要仔细评估。 |
| LLM Decode | 收益有限 | Q 只有少量 token,瓶颈常在 KV 读取和调度;FlashAttention/推理框架的 split-KV、paged KV 通常更成熟。 |
| MLA / DeepSeek 类 absorbed attention | 不建议直接套 | 结构非标准,等效 head_dim/数据流可能不满足 SageAttention 主包限制;应使用 vLLM/SGLang 等专用 MLA kernel。 |
| 原生 FP8 模型 | 看目标 | FlashAttention3-FP8 避免反复转换;SageAttention 需要回到 FP16/BF16 输入再量化成 INT8/FP8,可能有额外开销。 |
| 硬件 | 复习结论 | 准确性限定 |
|---|---|---|
| A100 / SM80 | 训练用 FA2/框架内核;推理可测 SageAttention CUDA INT8+FP16。 | SageAttention 自动分支为 SM80 FP16 PV,默认 FP32 累加。 |
| RTX 3090 / SM86 | SageAttention 自动走 Triton INT8+FP16。 | 没有 FP8 PV 自动路径,功能相对简单。 |
| RTX 4090 / L20 / SM89 | SageAttention2++ 的典型强项:INT8 QK + FP8 PV + fp32+fp16。 |
需要 CUDA 12.4+ 构建 FP8 支持。 |
| H100/H20 / SM90 | Diffusion 推理可比较 SageAttention 与 FA3/FA3-FP8;训练和复杂 LLM 能力优先 FA3/FA4/框架。 | SageAttention SM90 只建议 fp32+fp32 路径;普通 fp32 未实现。 |
| B200/GB200 / SM100 | 倾向 FA4;本地 SageAttention 主包没有 SM100 自动分派。 | 主包 setup 有 SM100 编译标志,但公开 sageattn() 进入 unknown arch 抛错;SA3 本地 C++ 入口也只放行 SM120/121。 |
| RTX 5090 / SM120/121 | SageAttention 主包自动走 INT8+FP8 per_warp;SA3 FP4 包也面向 SM120/121。 | 需要 CUDA 12.8+;具体端到端收益仍要用模型实测。 |
是否训练?
是 -> FlashAttention / 框架专用 kernel
否 -> 继续
是否 B200/SM100?
是 -> FA4 或框架内核;当前本地 SageAttention 主包无 sageattn() 自动路径
否 -> 继续
是否标准 MHA/GQA 且 head_dim <= 128?
否 -> FlashAttention / 模型专用内核
是 -> 继续
是否视觉/视频生成、ViT、长 prompt prefill?
是 -> SageAttention 值得优先压测
否 -> 继续
是否 decode、paged KV、sliding window、block-sparse、自定义 score/mask?
是 -> FlashAttention / FA4 / 推理框架内核
否 -> 用 FlashAttention 做基线,再比较 SageAttention 端到端质量和吞吐
参数转发:sageattn() 接收 **kwargs 但不传下去;需要改 pv_accum_dtype 或 smooth_v 时,请调用显式后端函数。
layout:HND 与 NHD 只改变 shape 解释,不会自动重排输入;传错 layout 会让 head/seq 维度完全错位。
head_dim:源码只支持 padding 到 64/128;原始 head_dim 大于 128 会直接抛错。
SM100/B200:主包 setup 有 SM100 编译标志,但 sageattn() 没有 SM100 分派;Blackwell FP4 包的 C++ 入口又检查 SM120/121。
这是一处“能编译部分标志”和“公开入口可运行”并不等价的边界。
变长 smooth_k:sageattn_varlen() 对 K mean 的注释说明它按全部 batch token 计算,不是每个序列各自计算。
causal mask:Triton FP16 路径里 causal 与 attn_mask 互斥;CUDA 显式路径没有通用 attn_mask 参数。
按这个顺序读,能最快把“API 怎么进来、数据怎么变、kernel 怎么发射”串起来。
sageattention/__init__.py:1 看包对外暴露哪些 API。
core.py:79 读 sageattn() 的参数、文档和 SM 分支。
core.py:160 看 SM86/Triton 的量化、mask 和 return_lse。
core.py:334 看 sageattn_varlen() 的 cu_seqlens 与 per-block 量化。
core.py:451 看 FP16 PV 的三种累加选择。
core.py:636 看 FP8 V 量化与两级累加。
core.py:829 看 Hopper WGMMA/TMA 后端入口。
quant.py:22 到 quant.py:224 串起 INT8/FP8 准备。
csrc/fused/pybind.cpp:21 看 Python 调用的 fused CUDA 符号。
qk_int_sv_f8_cuda_sm90.cu:126 看 TMA、scale index、online softmax 状态。
setup.py:129 看 CUDA 版本和架构编译条件。
sageattn3/api.py:75 看 Blackwell FP4 预处理和量化调用。
这一节专门把容易随时间变化、或和当前源码边界不同的表述校准成复习时更安全的说法。
| 主题 | 复习时的安全表述 | 依据/边界 |
|---|---|---|
| SA2/SA3 与 SM100/B200 | 当前本地 SageAttention 主包 sageattn() 没有 SM100 自动分派;SA3 Blackwell 包当前 C++ 入口检查 SM120/121。 |
主包 core.py:143 分派缺少 SM100;SA3 api.cu:219 检查 SM120/121。 |
| “SageAttention 无 backward” | 对当前主包源码可以这样说:公开 API 和 pybind 主要是 forward attention,没有训练 backward 替换路径。 | 不要把它外推成论文永远不研究训练;SA3 论文标题包含 8-bit training exploration,但本地公开入口仍是推理前向。 |
| FA4 版本/PR/性能数字 | 作为论文和生态背景复习,不作为本 SageAttention 仓库事实。实际项目中要重新查上游 release、README、benchmark。 | 这类信息高度时间敏感,本文只保留技术方向:Blackwell、TMEM/UMMA、CuTe DSL、FlexAttention、block-sparse。 |
| H100 上 SageAttention vs FA3-FP8 | 可作为 Diffusion 推理候选,但必须用目标模型端到端质量和吞吐实测。 | SageAttention README 的示例和 TOPS 图强调 kernel 速度,且 TOPS 不含量化/smoothing 开销。 |
| LLM Prefill/Decode | Prefill 更可能从低比特矩阵乘受益;Decode 常受 KV 读、调度和 cache 机制限制,FlashAttention/推理框架更常见。 | 这是工程推断,不是本 SageAttention 仓库提供的 LLM decode 专用 API。 |
| 直接替换 SDPA | 可以快速试验,但不应当作为所有模型的最终集成方式。 | README 明确提示并非所有模型都适合 F.scaled_dot_product_attention = sageattn,图像/视频模型建议改 Attention class。 |
最后的趋势可以总结为:IO-aware、低比特量化、可编程稀疏和推理框架动态调度会继续融合。
FA4 代表的是架构深度绑定的 exact attention 路线:利用 Blackwell 的异步 MMA、TMEM、CuTe DSL 和 FlexAttention, 在保持数值形式更稳定的同时扩展 block-sparse、score/mask modifier 和训练能力。
SageAttention2/2++ 证明了 INT8 QK + FP8 PV 在推理中可以获得很强吞吐;SageAttention3 进一步探索 FP4 和训练相关问题。 但当前本地源码中,SA3 是独立 Blackwell FP4 包,主包和 SA3 的硬件入口边界要分开记。
对 vLLM、SGLang 这类引擎,未来更现实的路线不是单一 attention kernel 通吃,而是按 GPU、模型结构、prefill/decode 阶段、KV cache 和精度目标动态选择后端。
B200 / 训练 / 复杂 mask -> FlashAttention / FA4 / 框架内核
视觉视频推理 / 标准 MHA/GQA -> SageAttention 值得优先压测
LLM prefill -> 两者都测,关注质量和端到端吞吐
LLM decode / paged KV -> FlashAttention 或推理框架专用内核
SM100 上当前本地 SageAttention -> 不要假设 sageattn() 可直接跑
参考文献列表,并补充源码入口。外部论文链接用于理解技术路线,源码链接用于确认当前实现。
Dao, “FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning,” ICLR 2024。
读完源码后,可以用这些问题自测是否真正抓住实现骨架。
sageattn() 在 SM80/86/89/90/120/121 分别进入哪个后端。smooth_k 的均值减法在哪里计算、在哪里融合、为什么返回 LSE 时要校正。per_block、per_warp、per_thread 的 scale 形状和工程取舍。pv_accum_dtype="fp32+fp16" 为什么把 V FP8 的 scale_max 降到 2.25。sageattention3_blackwell FP4 包的边界。scaled_dot_product_attention,以及何时需要改 Attention class。