# attention.py
# Rotary-position-embedding (RoPE) attention building blocks.
import math
import torch
from torch import nn
# Helper function to rotate the last half of a tensor
# Used in rotary positional embeddings to compute sine and cosine rotations
def rotate_half(x):
    """Return ``x`` with the halves of its last axis swapped and the upper half negated.

    The last axis is cut at ``x.shape[-1] // 2`` into ``lower`` and ``upper``;
    the result is ``[-upper, lower]`` concatenated along that same axis, so the
    output has the same shape as the input.
    """
    cut = x.shape[-1] // 2
    lower = x[..., :cut]
    upper = x[..., cut:]
    return torch.cat([upper.neg(), lower], dim=-1)
# Applies rotary positional embeddings to query (q) and key (k) tensors
# Uses sine and cosine positional encodings to enhance positional awareness
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate the query and key tensors using the supplied cosine/sine tables.

    Args:
        q: Query tensor.
        k: Key tensor.
        cos: Cosine table; a singleton axis is inserted at ``unsqueeze_dim``
            so it can broadcast against ``q`` and ``k``.
        sin: Sine table, treated the same way as ``cos``.
        position_ids: Accepted for signature compatibility; not read here.
        unsqueeze_dim: Axis at which the singleton dimension is inserted.

    Returns:
        A ``(q_embed, k_embed)`` tuple where each tensor ``t`` has been mapped
        to ``t * cos + rotate_half(t) * sin``.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        # Same arithmetic for queries and keys.
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)
# Repeats key-value tensors for multiple attention heads
# Ensures compatibility between the number of attention heads and key-value heads
def repeat_kv(hidden_states, n_rep):
    """Repeat every key/value head ``n_rep`` times along the head axis.

    Args:
        hidden_states: 4-D tensor unpacked as
            ``(batch, num_key_value_heads, seq_len, head_dim)``.
        n_rep: Number of consecutive copies to make of each head.

    Returns:
        Tensor of shape ``(batch, num_key_value_heads * n_rep, seq_len, head_dim)``
        in which the copies of one head are adjacent to each other.
    """
    bsz, kv_heads, seq_len, width = hidden_states.shape
    # Insert a repeat axis right after the head axis and broadcast over it.
    tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, width)
    # Fold the repeat axis into the head axis.
    return tiled.reshape(bsz, kv_heads * n_rep, seq_len, width)
# Computes rotary positional embeddings for queries and keys
class RotaryEmbedder(nn.Module):
    """Produce the cosine and sine tables used by rotary positional embeddings.

    Args:
        dim: Upper bound of the even index range ``0, 2, ... < dim`` for which
            one inverse frequency is computed.
        base: Base of the geometric progression; frequency ``i`` is
            ``base ** (-i / dim)``.
    """

    def __init__(self, dim, base):
        super().__init__()
        # Kept as a plain float32 attribute (not a parameter or buffer) so the
        # module's state_dict is unaffected and dtype casts of the module
        # (e.g. ``.half()``) do not reduce the precision of the frequencies.
        exponents = torch.arange(0, dim, 2, dtype=torch.float32) / dim
        self.freq = 1.0 / (base ** exponents)

    @torch.no_grad()
    def forward(self, x):
        """Return ``(cos, sin)`` tables for the positions of ``x``.

        Args:
            x: Tensor whose second-to-last axis length is the number of
                positions and whose device decides where the tables live.

        Returns:
            Two float32 tensors of shape ``(1, x.shape[-2], 2 * len(self.freq))``
            located on ``x.device``.
        """
        device = x.device
        # ``self.freq`` is not moved by ``Module.to()``; bring it (and the
        # position indices) to the input's device so the tables can be
        # combined with ``x`` without a device mismatch. This is a no-op when
        # everything is already on the same device.
        freq = self.freq.to(device)
        positions = torch.arange(x.shape[-2], dtype=torch.long, device=device)
        # Outer product: angle[p, f] = position p * frequency f.
        angles = torch.outer(positions.float(), freq).unsqueeze(dim=0)
        # The angle block is repeated so both halves of the last axis share it.
        emb = torch.cat((angles, angles), dim=-1)
        return emb.cos(), emb.sin()
# Implements attention with rotary positional embeddings
class RopeAttention(nn.Module):
    """Multi-head self-attention with rotary positional embeddings.

    Keys and values are projected to ``kv_heads`` heads and each one is shared
    by ``num_heads // kv_heads`` query heads.

    Args:
        config: Object exposing ``hidden_size``, ``num_heads`` and ``kv_heads``.

    Raises:
        ValueError: If ``hidden_size`` is not a multiple of ``num_heads`` or
            ``num_heads`` is not a positive multiple of ``kv_heads``. Such
            configurations cannot run ``forward`` and are rejected up front.
    """

    def __init__(self, config):
        super().__init__()
        # Model dimensions and head layout.
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_heads
        self.head_dim = config.hidden_size // self.num_heads
        self.kv_heads = config.kv_heads  # Number of key/value heads
        self.rope_theta = 10000.0  # Base of the rotary frequency progression

        if self.head_dim * self.num_heads != self.hidden_size:
            raise ValueError(
                f"hidden_size ({self.hidden_size}) must be divisible by "
                f"num_heads ({self.num_heads})"
            )
        if self.kv_heads <= 0 or self.num_heads % self.kv_heads != 0:
            raise ValueError(
                f"num_heads ({self.num_heads}) must be a positive multiple of "
                f"kv_heads ({self.kv_heads})"
            )

        # Bias-free linear projections for queries, keys, values and output.
        self.W_query = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.W_key = nn.Linear(config.hidden_size, self.kv_heads * self.head_dim, bias=False)
        self.W_value = nn.Linear(config.hidden_size, self.kv_heads * self.head_dim, bias=False)
        self.W_output = nn.Linear(config.hidden_size, config.hidden_size, bias=False)

        # Generator of the cosine/sine tables applied to queries and keys.
        self.rotary_emb = RotaryEmbedder(base=self.rope_theta, dim=self.head_dim)

    def forward(self, hidden_states: torch.Tensor, attention_mask=None):
        """Run self-attention over ``hidden_states``.

        Args:
            hidden_states: Tensor of shape ``(batch, seq_len, hidden_size)``.
            attention_mask: Optional additive mask that must broadcast against
                the ``(batch, num_heads, seq_len, seq_len)`` score tensor. When
                ``None`` the scores are left unmasked (no causal mask is
                created here).

        Returns:
            Tensor of shape ``(batch, seq_len, hidden_size)``.
        """
        batch, seq_len, _ = hidden_states.size()

        # Project, then split the feature axis into heads:
        # (batch, seq_len, heads * head_dim) -> (batch, heads, seq_len, head_dim).
        q_states = self.W_query(hidden_states)
        k_states = self.W_key(hidden_states)
        v_states = self.W_value(hidden_states)
        q_states = q_states.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k_states = k_states.view(batch, seq_len, self.kv_heads, self.head_dim).transpose(1, 2)
        v_states = v_states.view(batch, seq_len, self.kv_heads, self.head_dim).transpose(1, 2)

        # Rotary positional embeddings are applied to queries and keys only.
        cos, sin = self.rotary_emb(q_states)
        q_states, k_states = apply_rotary_pos_emb(q_states, k_states, cos, sin)

        # Repeat key/value heads so there is one per query head.
        kv_groups = self.num_heads // self.kv_heads
        k_states = repeat_kv(k_states, kv_groups)
        v_states = repeat_kv(v_states, kv_groups)

        # Scaled dot-product scores: (batch, num_heads, seq_len, seq_len).
        attn_weights = torch.matmul(q_states, k_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        # The mask is optional; adding ``None`` would raise a TypeError.
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # No attention dropout is applied (the rate was fixed at 0, a no-op).
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Weighted sum of values, then merge the heads back into one axis.
        attn_output = torch.matmul(attn_weights, v_states)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch, seq_len, -1)

        # Final projection back to the hidden size.
        return self.W_output(attn_output)