搞懂flash_attention
本文记录了学习flash_attention遇到的一些好的文章,帮助你搞懂flash_attention。
我们知道现在的LLM大模型主流是基于attention搭建的,attention的计算效率也决定了生产场景中大模型的可用性。flash_attention目前有三个版本,分别是flash_attention和flash_attention2和flash_attention3,它们的目的都是采取一系列的优化手段,提高attention的计算效率。
根据Flash Attention原理详解,flashattention的核心思想是减少HBM的访问,将QKV切分为小块后放入SRAM中。FlashAttention 优化了显存存取,要搞懂flashattension就要搞懂softmax的优化计算,手撕online softmax, Flash Attention前传,一撕一个不吱声,手撕LLM-Flash Attention从softmax说起。
X_batch = torch.randn(4, 6)
_, d = X_batch.shape
X_batch_block_0 = X_batch[:, :d//2]
X_batch_block_1 = X_batch[:, d//2:]
# we parallel calculate different block max & sum
X_batch_0_max, _ = X_batch_block_0.max(dim = 1, keepdim = True)
X_batch_0_sum = torch.exp(X_batch_block_0 - X_batch_0_max).sum(dim = 1, keepdim = True)
X_batch_1_max, _ = X_batch_block_1.max(dim = 1, keepdim = True)
X_batch_1_sum = torch.exp(X_batch_block_1 - X_batch_1_max).sum(dim = 1, keepdim = True)
# online batch block update max & sum
X_batch_1_max_update = torch.maximum(X_batch_0_max, X_batch_1_max) # 逐个元素找最大值
X_batch_1_sum_update = X_batch_0_sum * torch.exp(X_batch_0_max - X_batch_1_max_update) \
+ torch.exp(X_batch_block_1 - X_batch_1_max_update).sum(dim = 1, keepdim = True) # block sum
X_batch_online_softmax = torch.exp(X_batch - X_batch_1_max_update) / X_batch_1_sum_update
print(X_batch_online_softmax)
下面这句为啥和公式不一样呢?实际上是加的下面图中的两个红框。
X_batch_1_sum_update = X_batch_0_sum * torch.exp(X_batch_0_max - X_batch_1_max_update) \
+ torch.exp(X_batch_block_1 - X_batch_1_max_update).sum(dim = 1, keepdim = True) # block sum
flash_attention2¶
手撕LLM-FlashAttention2只因For循环优化的太美,Flash Attention 2比Flash Attention 1加速2x, 计算效率达到GEMM性能的50~73%。
- 减少非乘法计算
- 优化QKV for循环顺序
- 采用shared memory减少通信
flash_attention3¶
gpu-mode-cutlass-and-flashattention-3
实现Softmax¶
为了使用CUDA实现Softmax并通过PyTorch调用,需要编写CUDA核函数和C++包装器, 创建两个文件:main.cpp和softmax.cu。然后,在Python中加载并调用CUDA扩展。
详细代码实现参考KenForever1/online_softmax。
实现FlashAttention¶
tspeterkim/flash-attention-minimal。