Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions book/di-10-jie-dai-ma-shi-xian-zhong-de-zhong-dian-xuan-du.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,14 @@ class TutorialLLM(nn.Module):
self.device = device
self.token_embedding_table = nn.Embedding(vocabulary_size, dim_embed)
self.position_embedding_table = nn.Embedding(max_length, dim_embed)
self.transformer_blocks = nn.Sequential(*[TranformerBlock(dim_embed, num_head, max_length) for _ in range(num_layer)])
self.transformer_blocks = nn.Sequential(*[TransformerBlock(dim_embed, num_head, max_length) for _ in range(num_layer)])
self.layer_norm_final = nn.LayerNorm(dim_embed)
self.project = nn.Linear(dim_embed, vocabulary_size)

def forward(self, token_ids: Tensor, labels: Tensor = None, reduce_loss: bool = True) -> tuple[Tensor, Optional[Tensor]]:
B, T = token_ids.shape
token_embedding = self.token_embedding_table(token_ids) # (B, T) -> (B, T, dim_embed)
position_embedding = self.position_embedding_table(torch.arange(T, device=self.device)) # (T) -> (T, dim_embed)
position_embedding = self.position_embedding_table(torch.arange(T, device=token_ids.device)) # (T) -> (T, dim_embed)
embedding = token_embedding + position_embedding # (B, T, dim_embed) + (T, dim_embed) -> (B, T, dim_embed)
embedding = self.transformer_blocks(embedding) # (B, T, dim_embed) -> (B, T, dim_embed)
embedding = self.layer_norm_final(embedding) # (B, T, dim_embed) -> (B, T, dim_embed)
Expand All @@ -98,7 +98,8 @@ class TutorialLLM(nn.Module):
B, T, vocabulary_size = logits.shape
logits = logits.view(B * T, vocabulary_size)
labels = labels.view(B * T)
loss = F.cross_entropy(logits, labels, reduce=reduce_loss)
reduction = 'mean' if reduce_loss else 'none'
loss = F.cross_entropy(logits, labels, reduction=reduction)

return logits, loss
```
Expand All @@ -122,7 +123,7 @@ Layer Norm则是对每一层的输入数据做归一化,把输入的分布转

{% code title="model.py" %}
```python
class TranformerBlock(nn.Module):
class TransformerBlock(nn.Module):

def __init__(self, dim_embed: int, num_heads: int, max_length: int) -> None:
super().__init__()
Expand Down Expand Up @@ -180,7 +181,7 @@ class AttentionHead(nn.Module):
query = self.project_to_query(input) # (B, T, dim_embed) -> (B, T, head_size)
value = self.project_to_value(input) # (B, T, dim_embed) -> (B, T, head_size)
weights = query @ key.transpose(-2, -1) # (B, T, head_size) @ (B, head_size, T) -> (B, T, T)
weights *= dim_embed ** -0.5
weights *= query.size(-1) ** -0.5
weights = weights.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
weights = F.softmax(weights, dim=-1)
output = weights @ value # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)
Expand All @@ -196,7 +197,7 @@ class AttentionHead(nn.Module):

于是,当我们计算`query @ key.transpose(-2, -1)`时,相当于将`query`中的每个字向量与`key`中的每个字向量求相似度,从而得到T×T大小的方阵,记作`weights`。方阵中的每个元素代表了某个字对另一个字来说的重要性。

紧接着是一个前文没有提到的操作,`weights *= dim_embed ** -0.5`。它让`weights`中的所有元素统一缩小根号`dim_embed`倍,本质上也是一种归一化,作用仍然是稳定模型的训练。
紧接着是一个前文没有提到的操作,`weights *= query.size(-1) ** -0.5`。它让`weights`中的所有元素统一缩小根号`head_size`倍,本质上也是一种归一化,作用仍然是稳定模型的训练。

然后,最关键的两步来了。`weights = weights.masked_fill(self.tril[:T, :T] == 0, float('-inf'))`利用最开始缓存的下三角矩阵将`weights`的上三角区域重置为负无穷,意味着每个字向量只能参考前面字的信息,不能参考后面字的信息。`weights = F.softmax(weights, dim=-1)`进一步将相似度转换为概率,此时被置为负无穷的位置的概率恰好变为0。

Expand Down
10 changes: 6 additions & 4 deletions dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(self, input_path: str = 'data.json', batch_size: int = 16, max_leng
+ For instruction finetune data, each poem is formatted as an instruction-response pair.
The instruction is a fixed string '請用以下題目寫一首詩' and a title, while the response is the paragraphs of the poem.
+ For alignment data, each item contains a positive-negative pair of poems. The positive pair is the original poem,
while the negative pair has at least one paragraph replaced by a random paragraph from other poems.
while the negative pair is sampled from a random non-five-words poem.

Data in each category will be further split into train and evaluate sets.
All the data will be tokenized into a token id sequence, where each token is a character in the vocabulary.
Expand Down Expand Up @@ -96,7 +96,7 @@ def __init__(self, input_path: str = 'data.json', batch_size: int = 16, max_leng
print(alignment_texts[0])

# Create a vocabulary from all the characters appeared in the poems and the instructions.
# Note that we add a special character '\0' in the end, which is used as an end-of-text token.
# Note that we add a special character '\0' in the end, which is used as an end-of-text token (it will be index 0 in the vocabulary).
# An end-of-text token is useful to let the model know when to stop generating text.
all_text = f'{pretrain_text}{"".join(finetune_texts)}{"".join([pair[0] + pair[1] for pair in alignment_texts])}\0'
# Get a sorted list of unique characters
Expand Down Expand Up @@ -137,7 +137,7 @@ def get_batch_pretrain(self, split: str) -> tuple[Tensor, Tensor]:

Returns:
Two tensors of shape (`batch_size`, `max_length`), where the first tensor is the input tokens and the second tensor is the label tokens.
The second dimension is the length of the text. We formed each label by shifting the input by one character to the right.
The second dimension is the length of the text. We formed each label by shifting the input by one character to the right to let the model learn to predict the next character.
"""
# Choose train or evaluate split
data = self.pretrain_train_data if split == 'train' else self.pretrain_evaluate_data
Expand Down Expand Up @@ -229,12 +229,14 @@ def get_batch_generator_alignment(self, split: str) -> Generator[tuple[Tensor, T
def process_batch(self, batch: list) -> tuple[Tensor, Tensor]:
"""
Process a batch of token id lists.
Pad positions beyond each item's actual length with 0, then mask those label positions by setting them to -100.
This keeps the first 0 label as the end-of-text marker and ignores the remaining padding tokens in the loss calculation.

Comment on lines 231 to 234
Args:
batch: A list of token id lists, where each list is a poem represented by token ids.

Returns:
A batch of input token id lists and label token ids. The label refer to the next character of each input sequence
A batch of input token id lists and label token ids. The label refers to the next character of each input sequence.
"""
# All the inputs and labels are initialized to zeros of largest length
inputs = torch.zeros(len(batch), self.max_length, dtype=torch.long)
Expand Down
22 changes: 12 additions & 10 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def forward(self, input: Tensor) -> Tensor:
value = self.project_to_value(input) # (B, T, dim_embed) -> (B, T, head_size)
# Compute the self-attention weights
weights = query @ key.transpose(-2, -1) # (B, T, head_size) @ (B, head_size, T) -> (B, T, T)
# Scale the attention weights to
weights *= dim_embed ** -0.5
# Scale the attention weights to avoid the problem of vanishing gradients.
weights *= query.size(-1) ** -0.5
# Mask the attention weights to respect the causal constraint
# Slice the tril matrix to fit the size of the current input
weights = weights.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
Expand Down Expand Up @@ -88,8 +88,9 @@ def __init__(self, dim_embed: int, num_heads: int, head_size: int, max_length: i
# Create a list of `num_heads` attention heads
self.heads = nn.ModuleList([AttentionHead(dim_embed, head_size, max_length) for _ in range(num_heads)])
# Create a linear layer to project the concatenated output of all heads to the original dimension.
# In our case, the concatenated output is happen to be the same as the original dimension, so we can skip
# this projection layer. But in general, the output of the heads may have different dimension than the input.
# In our case, the concatenated output happens to be the same as the original dimension, so we can skip
# this projection layer. But in general, the output of the heads may have different dimension than the input and
# we keep this projection layer to make sure the flow of computation is consistent.
self.project = nn.Linear(head_size * num_heads, dim_embed)

def forward(self, input: Tensor) -> Tensor:
Expand Down Expand Up @@ -148,7 +149,7 @@ def forward(self, input: Tensor) -> Tensor:
"""
return self.feed_forward(input) # (B, T, dim_embed) -> (B, T, 4 * dim_embed) -> (B, T, dim_embed)

class TranformerBlock(nn.Module):
class TransformerBlock(nn.Module):
"""
Transformer block.

Expand Down Expand Up @@ -182,8 +183,7 @@ def forward(self, input: Tensor) -> Tensor:
"""
Compute the output of the transformer block for the input tensor.

We treat the attention heads and the feed-forward neural network as residual
steams.
We treat the attention heads and the feed-forward neural network as residual streams.

Args:
input: A tensor of shape (B, T, `dim_embed`) where B is the batch size,
Expand Down Expand Up @@ -228,7 +228,7 @@ def __init__(self, vocabulary_size: int, dim_embed: int, max_length: int, num_he
# Create a position embedding table to add positional information to the token vectors
self.position_embedding_table = nn.Embedding(max_length, dim_embed)
# Create a series of transformer blocks
self.transformer_blocks = nn.Sequential(*[TranformerBlock(dim_embed, num_head, max_length) for _ in range(num_layer)])
self.transformer_blocks = nn.Sequential(*[TransformerBlock(dim_embed, num_head, max_length) for _ in range(num_layer)])
# Create a layer normalization layer for the final output
self.layer_norm_final = nn.LayerNorm(dim_embed)
# Create a linear layer to project the output from embedding space to vocabulary space
Expand All @@ -251,7 +251,8 @@ def forward(self, token_ids: Tensor, labels: Tensor = None, reduce_loss: bool =
B, T = token_ids.shape
# Get the token embedding and position embedding
token_embedding = self.token_embedding_table(token_ids) # (B, T) -> (B, T, dim_embed)
position_embedding = self.position_embedding_table(torch.arange(T, device=self.device)) # (T) -> (T, dim_embed)
# The absolute position embedding is quite old fashioned but it's good enough for our tutorial
position_embedding = self.position_embedding_table(torch.arange(T, device=token_ids.device)) # (T) -> (T, dim_embed)
# Add the token embedding and position embedding in the last dimension
embedding = token_embedding + position_embedding # (B, T, dim_embed) + (T, dim_embed) -> (B, T, dim_embed)
# Send the embedding through the transformer blocks
Expand All @@ -270,7 +271,8 @@ def forward(self, token_ids: Tensor, labels: Tensor = None, reduce_loss: bool =
# Flatten the labels to a list of token ids
labels = labels.view(B * T)
# Compute the cross-entropy loss between the logits and the labels
loss = F.cross_entropy(logits, labels, reduce=reduce_loss)
reduction = 'mean' if reduce_loss else 'none'
loss = F.cross_entropy(logits, labels, reduction=reduction)

return logits, loss

Expand Down
25 changes: 25 additions & 0 deletions tests/test_run.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import sys
import warnings
import torch

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
Expand All @@ -9,6 +10,30 @@
from model import TutorialLLM
from trainer import Trainer


def test_forward_reduce_loss_false_returns_unreduced_loss_without_warning():
    """
    Test the unreduced loss path used by DPO without relying on dataset loading.

    With ``reduce_loss=False`` the forward pass should return one loss value per
    token (shape ``(B * T,)``) rather than a scalar mean, and switching from the
    deprecated ``reduce=`` argument to ``reduction=`` in ``F.cross_entropy``
    should keep the call free of the deprecation warning.
    """
    torch.manual_seed(2024)
    # Keep the constructor device intentionally stale to verify the positional indices follow the input tensor device.
    model = TutorialLLM(vocabulary_size=17, dim_embed=8, max_length=4, num_head=2, num_layer=1, device='cuda')
    # B=2 sequences of T=4 token ids drawn from a 17-token vocabulary; created on CPU on purpose.
    token_ids = torch.randint(0, 17, (2, 4))
    labels = torch.randint(0, 17, (2, 4))

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        logits, loss = model(token_ids, labels, reduce_loss=False)

    # Logits are flattened to (B * T, vocabulary_size); the unreduced loss has one entry per token.
    assert logits.shape == (8, 17)
    assert loss.shape == (8,)
    # The legacy `reduce`/`size_average` arguments emit a warning mentioning both names; none should appear.
    assert not any(
        'size_average' in str(warning.message) and 'reduce args' in str(warning.message)
        for warning in caught
    )



def test_run():
"""
Test that the overall pipeline runs without error.
Expand Down
Loading