## The Attention Block [Step2c: The KV Cache; Step2d: Group Query Attention]
## As mentioned before, the naming convention follows Meta's original Llama 3 GitHub repository
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        # Embedding dimension
        self.dim = args.dim
        # Number of heads assigned to Query
        self.n_heads = args.n_heads
        # Number of heads assigned to Key and Value. If None, it defaults to the number of Query heads.
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        # Dimension of each head: model dimension divided by the number of Query heads
        self.head_dim = args.dim // args.n_heads
        # Number of repetitions needed so the Key/Value heads match the number of Query heads
        self.n_rep = args.n_heads // args.n_kv_heads
        # Weight initialization for the Query, Key, Value and Output projections. Notice that the out_features of the Q and KV weights depend on their respective head counts.
        self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False, device=device)
        self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False, device=device)
        self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False, device=device)
        self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False, device=device)
        # Initialize caches to store Keys and Values at the start. (KV Cache Implementation)
        self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim), device=args.device)
        self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim), device=args.device)
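        # Example (assuming the test configuration used later: dim=512, n_heads=8, n_kv_heads=4):
        #   head_dim = 512 // 8 = 64, n_rep = 8 // 4 = 2
        #   wq maps 512 -> 8*64 = 512, while wk/wv map 512 -> 4*64 = 256, so every 2 Query heads share 1 Key/Value head
        #   cache_k / cache_v each have shape [max_batch_size, max_seq_len, 4, 64]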
    def forward(self, x: torch.Tensor, start_pos, inference):
        # Shape of the input embedding: [bsz,seq_len,dim]
        bsz, seq_len, _ = x.shape
        # The mask is only used during training; it is not required during inference thanks to the KV cache.
        mask = None
        xq = self.wq(x)  # x[bsz,seq_len,dim] * wq[dim,n_heads * head_dim] -> q[bsz,seq_len,n_heads * head_dim]
        xk = self.wk(x)  # x[bsz,seq_len,dim] * wk[dim,n_kv_heads * head_dim] -> k[bsz,seq_len,n_kv_heads * head_dim]
        xv = self.wv(x)  # x[bsz,seq_len,dim] * wv[dim,n_kv_heads * head_dim] -> v[bsz,seq_len,n_kv_heads * head_dim]
        # Reshape Queries, Keys and Values by their number of heads. (Group Query Attention Implementation)
        xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim)     # xq[bsz,seq_len,n_heads,head_dim]
        xk = xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim)  # xk[bsz,seq_len,n_kv_heads,head_dim]
        xv = xv.view(bsz, seq_len, self.n_kv_heads, self.head_dim)  # xv[bsz,seq_len,n_kv_heads,head_dim]
        # Model - Inference Mode: the KV cache is enabled in inference mode only.
        if inference:
            # Compute the rotation matrix for each position in the sequence
            freqs_cis = precompute_freqs_cis(dim=self.head_dim, seq_len=self.args.max_seq_len * 2)
            # During inference, take only the rotation matrices for the current token positions.
            freqs_cis = freqs_cis[start_pos : start_pos + seq_len]
            # Apply RoPE to the Query and Key embeddings
            xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
            self.cache_k = self.cache_k.to(xq)
            self.cache_v = self.cache_v.to(xq)
            # Store the Key and Value token embeddings into their respective caches [KV Cache Implementation]
            self.cache_k[:bsz, start_pos:start_pos + seq_len] = xk
            self.cache_v[:bsz, start_pos:start_pos + seq_len] = xv
            # Assign all previous token embeddings, up to the current token position, to the Keys and Values variables for the attention calculation
            keys = self.cache_k[:bsz, :start_pos + seq_len]
            values = self.cache_v[:bsz, :start_pos + seq_len]
            # At this point, the shapes of the Keys and Values do not match the Queries, which they must for the attention score computation.
            # Use the repeat_kv function to make the Keys/Values shape match the Queries shape.
            keys = repeat_kv(keys, self.n_rep)      # keys[bsz,start_pos + seq_len,n_heads,head_dim]
            values = repeat_kv(values, self.n_rep)  # values[bsz,start_pos + seq_len,n_heads,head_dim]
        # Mode - Training mode: the KV cache is not used
        else:
            # Compute the rotation matrix and apply RoPE to the Queries and Keys for training.
            freqs_cis = precompute_freqs_cis(dim=self.head_dim, seq_len=self.args.max_seq_len)
            # xq[bsz,seq_len,n_heads,head_dim], xk[bsz,seq_len,n_kv_heads,head_dim]
            xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
            # Use the repeat_kv function to make the Keys/Values shape match the Queries shape
            # keys[bsz,seq_len,n_heads,head_dim], values[bsz,seq_len,n_heads,head_dim]
            keys = repeat_kv(xk, self.n_rep)
            values = repeat_kv(xv, self.n_rep)
            # For training mode, we compute a causal mask and apply it to the attention scores later
            mask = torch.full((seq_len, seq_len), float("-inf"), device=self.args.device)
            mask = torch.triu(mask, diagonal=1).to(self.args.device)
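            # e.g. for seq_len = 3, the mask (added to the scores before softmax) is:
            # [[0., -inf, -inf],
            #  [0.,   0., -inf],
            #  [0.,   0.,   0.]]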
        # To compute attention, transpose the Queries, Keys and Values so that the heads are at dim 1 and the sequence at dim 2
        xq = xq.transpose(1, 2)          # xq[bsz,n_heads,seq_len,head_dim]
        keys = keys.transpose(1, 2)      # keys[bsz,n_heads,seq_len,head_dim]
        values = values.transpose(1, 2)  # values[bsz,n_heads,seq_len,head_dim]
        # Compute the attention scores
        scores = torch.matmul(xq, keys.transpose(2, 3)).to(self.args.device) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask
        # Apply softmax to the attention scores
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        # Matrix multiplication of the attention scores with the Values
        output = torch.matmul(scores, values).to(self.args.device)
        # We get the contextual embedding for each head.
        # All heads need to be reshaped back and combined to give a single contextual attention output.
        # Shape change: output[bsz,n_heads,seq_len,head_dim] -> output[bsz,seq_len,n_heads,head_dim] -> output[bsz,seq_len,n_heads * head_dim]
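        # With the test configuration used later (bsz=10, seq_len=256, n_heads=8, head_dim=64): [10,8,256,64] -> [10,256,8,64] -> [10,256,512]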
        output = output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
        # shape: output[bsz,seq_len,dim]
        return self.wo(output)
# If the number of Key/Value heads is less than the number of Query heads, this function expands the Key/Value embeddings with the required number of repetitions
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    bsz, seq_len, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bsz, seq_len, n_kv_heads, n_rep, head_dim)
        .reshape(bsz, seq_len, n_kv_heads * n_rep, head_dim)
    )
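# e.g. with n_kv_heads=4 and n_rep=2, the Key/Value heads [k0,k1,k2,k3] become [k0,k0,k1,k1,k2,k2,k3,k3],
# so each Key/Value head is shared by a group of 2 consecutive Query heads.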
### Test: repeat_kv function ###
# note: xk and x_norm were already computed during the RoPE and RMSNorm tests and are reused here.
# You need to remove the triple quotes below to run the tests.
"""
n_rep = ModelArgs.n_heads // ModelArgs.n_kv_heads
keys = repeat_kv(xk, n_rep)
print(f"xk.shape: {xk.shape}")
print(f"keys.shape: {keys.shape}")
## Test: Attention function
# You need to remove the triple quotes below to run the test.
attention = Attention(ModelArgs)
x_out = attention(x_norm,start_pos=0, inference=False)
print(f"x_out.shape: {x_out.shape}")
"""
### Test Results: ###
"""
xk.shape: torch.Size([10, 256, 4, 64])
keys.shape: torch.Size([10, 256, 8, 64])
x_out.shape: torch.Size([10, 256, 512])
"""