
How to implement paged_attention: a cache manager on top of flash-attention's PagedAttention kernel

Much like the way the Linux operating system manages memory, paged_attention manages the GPU memory allocated to the KV cache during LLM inference: a page-table mechanism optimizes memory allocation and reduces fragmentation.

Introduction to paged_attention

Traditionally, a request's KV cache is:

  • stored in a contiguous region of memory;
  • pre-allocated to the maximum context length (8192 for Llama3).

This causes severe memory fragmentation. For example, if a request ends up generating only 792 tokens, roughly 90% (= 7400/8192) of the pre-allocated memory is fragmented, i.e. it cannot be used by any other request.
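
To make the 90% figure concrete, here is the arithmetic as a tiny script (not from the original post):

max_context = 8192   # slots pre-allocated per request (Llama3)
generated = 792      # tokens actually produced by the request

wasted = max_context - generated
print(wasted, f'{wasted / max_context:.1%}')  # 7400 90.3%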

To reduce memory fragmentation and increase request throughput (batch size), PagedAttention provides a non-contiguous KV-cache memory management scheme that roughly follows OS paging. It guarantees that fragmentation only occurs in the last allocated block of each request: in the figure, these are the parts outlined in red, where request A has 3 tokens in physical block 3 and request B has 2 tokens in physical block 2.

Seen from the calling code, the difference between attention and paged_attention is:

# attention
y = attn(k_cache=k_cache, v_cache=v_cache, ...)
# paged_attention
y = paged_attn(k_cache=k_cache_paged, v_cache=v_cache_paged, block_table=block_table, ...)

Unlike k_cache, k_cache_paged is non-contiguous and shared by all requests. Physical blocks 0–8 can be assigned to any request, which is why we also pass block_table, which records each request's logical-to-physical block assignment. For example, in the figure above, block_table looks like {0: [7,1,3], 1: [5,2]} (0 and 1 are the indices of requests A and B, respectively).
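
Concretely, finding where a given token of a request lives physically is just an index into the request's block list. The snippet below is a minimal sketch built from the example mapping above (the helper name physical_location is made up for illustration):

block_size = 256

# Logical-to-physical mapping from the figure: request A (index 0) owns
# physical blocks 7, 1, 3; request B (index 1) owns physical blocks 5, 2.
block_table = {0: [7, 1, 3], 1: [5, 2]}

def physical_location(request_id, token_pos, block_size=block_size):
    # Map a request's logical token position to (physical_block, offset_in_block).
    logical_block = token_pos // block_size
    offset = token_pos % block_size
    return block_table[request_id][logical_block], offset

print(physical_location(0, 300))  # (1, 44): token 300 of request A sits at offset 44 of physical block 1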

Implementing a cache manager on top of flash-attention

Tall buildings rise from solid foundations: we can build on existing infrastructure, for example implementing a cache manager on top of flash-attention's PagedAttention kernel, or we can build everything from scratch.

This post takes the first route, building on Dao-AILab/flash-attention and its PagedAttention kernel, so the only thing left to implement is the cache manager. The kernel is designed to be used together with a cache manager (as in vLLM, for example) that decides when to allocate and free blocks and builds the block table. Because the cache manager depends on how the inference engine is structured, flash-attention does not ship one.

So the user implements a cache manager that decides when to allocate and free blocks and builds the block table (the block_table in the code below), then passes block_table to flash-attention:

from flash_attn import flash_attn_with_kvcache

y = flash_attn_with_kvcache(q, k_cache_paged, v_cache_paged, k, v, cache_seqlens=cache_seqlens, block_table=block_table, causal=True)

Introduction to flash_attn_with_kvcache

def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0, # 0.0 means deactivated
    rotary_interleaved=True,
    alibi_slopes=None,
    num_splits=0,
    return_softmax_lse=False,
):

If k and v are not None, k_cache and v_cache are updated in place with the new values from k and v. This is useful for decoding: you can pass in the cached keys/values from the previous step, update them with the current step's new keys/values, and then compute attention against the updated cache, all in a single kernel.

If you pass in k / v, you must make sure the cache is large enough to hold the new values. For example, the KV cache can be pre-allocated for the maximum sequence length (max_seq_len), and cache_seqlens is used to track the current length of each sequence in the batch.
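
To make the calling convention concrete, here is a minimal single-decode-step sketch with a paged cache. The shapes, block assignments, and sequence lengths below are illustrative assumptions (not taken from the post), and running it requires a CUDA GPU with flash-attn installed:

import torch
from flash_attn import flash_attn_with_kvcache

device = 'cuda'
bsz, n_heads, n_kv_heads, head_dim = 2, 32, 8, 64
num_blocks, block_size = 16, 256

# Paged KV cache shared by all requests: (num_blocks, block_size, n_kv_heads, head_dim)
k_cache_paged = torch.zeros(num_blocks, block_size, n_kv_heads, head_dim, device=device, dtype=torch.bfloat16)
v_cache_paged = torch.zeros_like(k_cache_paged)

# Each request owns a few physical blocks; unused slots are padded with -1.
block_table = torch.tensor([[7, 1, 3], [5, 2, -1]], dtype=torch.int32, device=device)
# Tokens already cached per request; the new token is written at this position.
cache_seqlens = torch.tensor([515, 258], dtype=torch.int32, device=device)

# One new query/key/value token per request for this decode step.
q = torch.randn(bsz, 1, n_heads, head_dim, device=device, dtype=torch.bfloat16)
k = torch.randn(bsz, 1, n_kv_heads, head_dim, device=device, dtype=torch.bfloat16)
v = torch.randn_like(k)

# Writes k/v into the paged cache at cache_seqlens and attends over the updated cache.
y = flash_attn_with_kvcache(q, k_cache_paged, v_cache_paged, k, v,
                            cache_seqlens=cache_seqlens, block_table=block_table, causal=True)
print(y.shape)  # torch.Size([2, 1, 32, 64])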

For the full details of flash_attn_with_kvcache, refer to the flash-attention interface documentation.

Implementing the CacheManager

For the complete implementation, see tspeterkim/paged-attention-minimal.

We implement a CacheManager class to manage the cache. The block size currently supported by Flash Attention is 256.

import math
import random

import torch

# bsz, n_kv_heads, head_dim, max_seq_len, n_layers, and device are assumed to be
# defined earlier in the inference script (they come from the model config and setup).
block_size = 256
class CacheManager:
    def __init__(self, tokens, block_size=block_size, batch_size=bsz, n_kv_heads=n_kv_heads, head_dim=head_dim):
        self.block_size = block_size
        self.batch_size = batch_size
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim
        self.num_blocks = (max_seq_len // block_size) * 5 # TODO: make this dynamic
        # block_table: {batch_id: [(physical_block_index, filled_positions), ...]}
        self.block_table = {i: [] for i in range(batch_size)}
        self.free_blocks = set(range(self.num_blocks))

        self.k_cache_paged = torch.randn(self.num_blocks, block_size, n_kv_heads, head_dim, device=device, dtype=torch.bfloat16)
        self.v_cache_paged = torch.randn(self.num_blocks, block_size, n_kv_heads, head_dim, device=device, dtype=torch.bfloat16)

        seq_lens = (tokens != -1).sum(1)
        for i, t in enumerate(seq_lens.tolist()): 
            num_blocks_to_reserve = math.ceil(t / block_size)
            # Tokens in the last (possibly partially filled) block;
            # equals block_size when t is an exact multiple of block_size.
            num_filled_positions = t - (num_blocks_to_reserve - 1) * block_size
            for b in range(num_blocks_to_reserve):
                index = self.get_free_block()
                if b == num_blocks_to_reserve-1:
                    self.block_table[i].append((index, num_filled_positions))
                else:
                    self.block_table[i].append((index, block_size))

    # Returns a free block to allocate more tokens to.
    # For simplicity, I raise an error when we run out of free blocks.
    # In the actual implementation, it solves this through scheduling and preemption (see paper)
    def get_free_block(self):
        if len(self.free_blocks) == 0:
            raise Exception('No more free blocks. Implement scheduling and preemption.')
        index = random.choice(list(self.free_blocks))
        self.free_blocks.remove(index)
        return index

    # Gets the logical block table that PagedAttention uses
    # TODO: Serial computation makes it slow. Is there a faster way?
    # Converts block_table into the int32 tensor layout that PagedAttention expects as input
    def get_block_table(self):
        max_len = max(len(b) for b in self.block_table.values())
        block_table = [[-1] * max_len for _ in range(self.batch_size)]
        # i is batch index, j is block index
        for i, b in self.block_table.items():
            for j, (index, _) in enumerate(b):
                block_table[i][j] = index
        return torch.tensor(block_table, dtype=torch.int32, device=device)

    def get_kv_cache(self):
        return self.k_cache_paged, self.v_cache_paged

    # Specific to my KV implementation. Returns the last sequence position given the block table.
    def get_last_pos(self):
        last_pos = [(len(b)-1)*self.block_size + b[len(b)-1][1]-1 for b in self.block_table.values()]
        return torch.tensor(last_pos, dtype=torch.int32, device=device)

    # Frees request's blocks.
    # Here, I leave one block, and free the rest. This is a limitation imposed by my kv cache implementation.
    # TODO: Avoid this limitation.
    def free_memory(self, index):
        blocks = self.block_table[index]
        if len(blocks) == 1:
            return
        for i, _ in blocks[1:]:
            self.free_blocks.add(i)
        self.block_table[index] = blocks[:1]

    # Updates block table and filled positions.
    # TODO: Again, pretty slow. Faster parallel way?
    def update(self, eos_reached, input_text_mask):
        for i, (eos, is_prompt) in enumerate(zip(eos_reached, input_text_mask)):
            if is_prompt: # if the token is part of the original prompt, we skip
                continue
            if eos: # free the request's blocks since we have generated the complete answer
                self.free_memory(i)
                continue

            old_index, n = self.block_table[i][-1]
            if n == self.block_size: # allocate new block if necessary
                new_index = self.get_free_block()
                self.block_table[i].append((new_index, 1))
            else: # otherwise, just use the next available slot in the block
                self.block_table[i][-1] = (old_index, n+1)

    def get_fragmented_memory_size(self):
        size = 0
        for b in self.block_table.values():
            _, filled = b[-1] # only the last block has fragmentation
            # bytes = empty slots in the last block * n_kv_heads * head_dim * 2 (bfloat16 bytes) * 2 (K and V caches)
            size += (self.block_size - filled) * self.n_kv_heads * self.head_dim * 2 * 2
        return size

# Create a CacheManager for each layer
cms = [CacheManager(tokens) for _ in range(n_layers)]
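
As a quick sanity check, the manager can be exercised on its own. The toy configuration below is hypothetical (in the real script, bsz, n_kv_heads, head_dim, max_seq_len, device, and tokens are all defined before the class and derived from the model config and the tokenized prompts):

bsz, n_kv_heads, head_dim, max_seq_len = 2, 8, 64, 2048
device = 'cuda'

# Two prompts of length 300 and 100, right-padded to max_seq_len with -1.
tokens = torch.full((bsz, max_seq_len), -1, dtype=torch.long, device=device)
tokens[0, :300] = 1
tokens[1, :100] = 1

cm = CacheManager(tokens, batch_size=bsz, n_kv_heads=n_kv_heads, head_dim=head_dim)
print(cm.get_block_table())  # shape (2, 2); request 1 needs only one block, so its row is padded with -1
print(cm.get_last_pos())     # tensor([299,  99], device='cuda:0', dtype=torch.int32)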

How do we use this cache manager to perform the paged_attention operation?

In the forward function, we run paged_attention in every layer.

def forward(tokens, start_pos):
    bsz, T = tokens.shape
    final_embedding = embedding_layer(tokens)
    freqs_cis = freqs_cis_max[start_pos:start_pos+T, :]

    for layer in range(n_layers):
        q_layer = model[f'layers.{layer}.attention.wq.weight']
        k_layer = model[f'layers.{layer}.attention.wk.weight']
        v_layer = model[f'layers.{layer}.attention.wv.weight']
        w_layer = model[f'layers.{layer}.attention.wo.weight']
        # ... (compute q, k, v for this layer from the weights above; omitted)

        # Ask this layer's CacheManager for the block_table and the paged kv cache
        block_table = cms[layer].get_block_table()
        # k_cache_paged.shape: torch.Size([160, 256, 8, 64]), i.e. (num_blocks, block_size, n_kv_heads, head_dim)
        k_cache_paged, v_cache_paged = cms[layer].get_kv_cache()
        cache_seqlens = torch.where(eos_reached, cms[layer].get_last_pos(), torch.tensor([start_pos]*bsz, dtype=torch.int32, device=device))
        # Run paged attention
        y = flash_attn_with_kvcache(q, k_cache_paged, v_cache_paged, k, v, cache_seqlens=cache_seqlens, block_table=block_table, causal=True)

        # (Pdb) p tokens.shape
        # torch.Size([1, 44])
        # (Pdb) p y.shape
        # torch.Size([1, 44, 32, 64])
        # (Pdb) p q.shape
        # torch.Size([1, 44, 32, 64])
        # (Pdb) p k.shape
        # torch.Size([1, 44, 8, 64])
        # (Pdb) p v.shape
        # torch.Size([1, 44, 8, 64])

        # Reshape the paged_attention output and feed it into the rest of the layer
        stacked_qkv_attention = y.view(bsz, T, dim)

        # (Pdb) p stacked_qkv_attention.shape
        # torch.Size([1, 44, 2048])

        embedding_delta = torch.matmul(stacked_qkv_attention, w_layer.T)
        # ... (rest of the transformer block omitted)

During decoding, we call the forward function, and then call cms[layer].update(eos_reached, input_text_mask) to update each CacheManager's block_table and free the blocks of finished requests.

# Do inference
for cur_pos in range(min_prompt_len, max_seq_len):
    next_token = forward(tokens[:,prev_pos:cur_pos], prev_pos)
    next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
    tokens[:, cur_pos] = next_token

    # Optional: uncomment (and "import pdb") to pause here each step and inspect state; see the Note below.
    # pdb.set_trace()

    # Update CacheManagers. Increment filled positions + allocate new block if required.
    for layer in range(n_layers):
        cms[layer].update(eos_reached.tolist(), input_text_mask[:, cur_pos].tolist())

    eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
    prev_pos = cur_pos

    if all(eos_reached):
        break
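
Once generation is done, the bookkeeping in the managers tells us how much of the paged cache is wasted. A small sketch (the reporting below is my own, not part of the original script):

# Only the last block of each request can be fragmented (see get_fragmented_memory_size above).
fragmented_bytes = sum(cm.get_fragmented_memory_size() for cm in cms)
total_bytes = sum(cm.k_cache_paged.numel() * cm.k_cache_paged.element_size() * 2 for cm in cms)  # K and V
print(f'fragmented: {fragmented_bytes / 1e6:.1f} MB ({100 * fragmented_bytes / total_bytes:.1f}% of the paged cache)')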

Note

If you want to debug a Python file, you can set a breakpoint by adding import pdb; pdb.set_trace() to the code. When the file is run, execution pauses at the breakpoint.
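
For example, to pause right after the block table is built inside forward (an illustrative placement, not part of the original code):

        block_table = cms[layer].get_block_table()
        import pdb; pdb.set_trace()  # inspect block_table, k_cache_paged, etc. interactively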

References

https://github.com/tspeterkim/paged-attention-minimal?tab=readme-ov-file
