diff --git a/src/gpt/attention.py b/src/gpt/attention.py index 2c3d537..0a73e9a 100644 --- a/src/gpt/attention.py +++ b/src/gpt/attention.py @@ -19,13 +19,18 @@ def __init__( self.key = nn.Linear(src_embed_dim, head_size, bias=False).to(self.device) # what am I? [b, c, h] self.value = nn.Linear(src_embed_dim, head_size, bias=False).to(self.device) # what can I tell you about me? self.dropout = nn.Dropout(dropout_p) + self.max_c = context_length # max context length # don't optimize the tril, that's only here for masking self.register_buffer("tril", torch.tril(torch.ones(context_length, context_length)).to(self.device)) def forward(self, x): - _, context, embed = x.shape + batch, context, embed = x.shape + if self.cache is None: + self.cache = torch.empty(batch, self.max_c, self.max_c, device=self.device) # kv cache k, q, v = self.key(x), self.query(x), self.value(x) + # calculate only newest values + weights = q @ k.transpose(-2, -1) # [b, c, h] @ [b, h, c] -> [b, c, c] weights = weights / embed ** (-0.5) # preserve variance of weights weights = weights.masked_fill(self.tril[:context, :context] == 0, float("-inf")) # only in decoder blocks diff --git a/tests/test_unit.py b/tests/test_unit.py index 39aad37..8eb9c98 100644 --- a/tests/test_unit.py +++ b/tests/test_unit.py @@ -1,7 +1,9 @@ import torch import pytest - +from torch.utils.data import DataLoader +from gpt.data import TextDataset from gpt.attention import MultiHeadAttention +from gpt.model import LM @pytest.fixture(params=[True, False]) @@ -27,3 +29,25 @@ def test_multihead_attention_forward(multihead_attention): x = torch.rand(1, 10, 256) # batch_size=1, context_length=10, src_embed_dim=256 output = multihead_attention(x) assert output.shape == (1, 10, 256) # batch_size=1, context_length=10, src_embed_dim=256 + + +def test_generation(): + text = open("data/tiny-shakespeare.txt").read() + train_dataset = TextDataset(text, device=torch.device("cpu"), context_length=64, batch_size=4) + + model = LM( + train_dataset.vocab_size, + context_length=train_dataset.context_length, + embed_dim=128, + num_layers=4, + num_heads=4, + dropout_p=0.2, + rope=False, + device=torch.device("cpu"), + ) + + prompt = "ROMEO:" + prompt_encoded = train_dataset.encode_batch([prompt]) + generated = model.generate(prompt_encoded, max_len=100) + generated_text = train_dataset.decode_batch(generated)[0] + assert len(generated_text) == 100