Fix deprecated loss reduction API and sync tutorial docs #3
Open
adny-code wants to merge 3 commits into jingedawang:main from
Conversation
Pull request overview
This PR updates the TutorialLLM implementation and tutorial materials to remove use of PyTorch’s deprecated F.cross_entropy(..., reduce=...) API, fixes the TransformerBlock naming typo, and adds a focused regression test for the unreduced-loss path used by DPO.
Changes:
- Replace the deprecated `reduce` argument in `F.cross_entropy` with `reduction` while preserving reduced vs. unreduced behavior (see the sketch below).
- Rename `TranformerBlock` to `TransformerBlock` in the implementation and partially in the tutorial chapter.
- Update dataset documentation text and add a regression test for `reduce_loss=False` (unreduced loss) without the deprecation warning.
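To make the first item concrete, here is a minimal sketch of the old and new calls. The `reduce_loss` flag and the (8, 17) shapes are taken from the PR description and tests; the helper name is hypothetical and the repository's actual forward pass may differ:

```python
import torch
import torch.nn.functional as F

def compute_loss(logits: torch.Tensor, targets: torch.Tensor, reduce_loss: bool = True) -> torch.Tensor:
    # Deprecated form (warns "size_average and reduce args will be deprecated"):
    #   F.cross_entropy(logits, targets, reduce=reduce_loss)
    # Forward-compatible form: map the boolean onto the `reduction` argument instead.
    reduction = 'mean' if reduce_loss else 'none'
    return F.cross_entropy(logits, targets, reduction=reduction)

logits = torch.randn(8, 17)            # (batch, num_classes)
targets = torch.randint(0, 17, (8,))   # (batch,)
print(compute_loss(logits, targets).shape)                     # torch.Size([])  - scalar mean loss
print(compute_loss(logits, targets, reduce_loss=False).shape)  # torch.Size([8]) - per-sample losses
```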
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| tests/test_run.py | Adds a regression test for the unreduced loss path (reduce_loss=False) and checks for absence of the deprecated-loss warning. |
| model.py | Fixes TransformerBlock typo usage and switches cross-entropy to reduction=; includes minor comment edits. |
| dataset.py | Aligns documentation text with actual alignment negative sampling and clarifies batch padding/masking docs. |
| book/di-10-jie-dai-ma-shi-xian-zhong-de-zhong-dian-xuan-du.md | Updates tutorial snippet to use TransformerBlock in one location (but still has remaining inconsistent/outdated snippets). |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Comment on lines 59 to 60
```python
# Scale the attention weights to avoid the problem of vanishing gradients.
weights *= dim_embed ** -0.5
```
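For context, a minimal sketch of where such a scaling typically sits in an attention head; the tensor names and shapes below are assumptions for illustration, not the repository's exact code:

```python
import torch

B, T, dim_embed = 2, 5, 16                 # hypothetical batch size, sequence length, embedding size
q = torch.randn(B, T, dim_embed)           # queries
k = torch.randn(B, T, dim_embed)           # keys

weights = q @ k.transpose(-2, -1)          # (B, T, T) raw attention scores
weights *= dim_embed ** -0.5               # keep score variance small so softmax does not saturate
weights = torch.softmax(weights, dim=-1)   # attention distribution per query position
```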
```python
# Get the token embedding and position embedding
token_embedding = self.token_embedding_table(token_ids)  # (B, T) -> (B, T, dim_embed)
# The absolute position embedding is quite old fashioned but it's good enough for our tutorial
position_embedding = self.position_embedding_table(torch.arange(T, device=self.device))  # (T) -> (T, dim_embed)
```
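The two embeddings above are usually combined by broadcast addition; a one-line illustration (the actual forward in model.py may differ):

```python
x = token_embedding + position_embedding  # (B, T, dim_embed) + (T, dim_embed) broadcasts over the batch
```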
```python
assert logits.shape == (8, 17)
assert loss.shape == (8,)
assert not any('size_average and reduce args will be deprecated' in str(w.message) for w in caught)
```
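The `caught` list checked above is presumably populated by recording warnings around the forward pass; a hedged sketch of that pattern (the model's forward signature is an assumption):

```python
import warnings

def run_unreduced_forward(model, token_ids, labels):
    # Record every warning emitted during the forward pass so the test can
    # assert that the deprecation message never appears.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        logits, loss = model(token_ids, labels, reduce_loss=False)
    return logits, loss, caught
```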
Comment on lines 99 to 100
```python
# Note that we add a special character '\0' in the end, which is used as an end-of-text token(will be index 0 in the vocabulary).
# An end-of-text token is useful to let the model know when to stop generating text.
```
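A minimal sketch of why '\0' ends up at index 0 when building a character vocabulary; the file name and construction are assumptions, and dataset.py may build its vocabulary differently:

```python
text = open('data.txt', encoding='utf-8').read()        # hypothetical corpus file
vocabulary = sorted(set(text + '\0'))                   # '\0' has code point 0, so it sorts first
char_to_id = {ch: i for i, ch in enumerate(vocabulary)}
assert char_to_id['\0'] == 0                            # the end-of-text token gets index 0
```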
Comment on lines +232 to +233
```
Emplace 0 to the positions that exceed the actual length of each item, and mask these positions in the label by setting them to -100.
This is necessary to let the model know where to stop(first 0 in label) and ignore the rest padding tokens in the loss calculation.
```
```
Returns:
    A batch of input token id lists and label token ids. The label refer to the next character of each input sequence
    A batch of input token id lists and label token ids. The label refer to the next character of each input sequence.
```
```python
self.token_embedding_table = nn.Embedding(vocabulary_size, dim_embed)
self.position_embedding_table = nn.Embedding(max_length, dim_embed)
self.transformer_blocks = nn.Sequential(*[TranformerBlock(dim_embed, num_head, max_length) for _ in range(num_layer)])
self.transformer_blocks = nn.Sequential(*[TransformerBlock(dim_embed, num_head, max_length) for _ in range(num_layer)])
```
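For readers following the tutorial, a rough sketch of what a block like `TransformerBlock` typically contains. This is not the repository's implementation; in particular, the causal mask implied by `max_length` is omitted here:

```python
import torch
import torch.nn as nn

class TransformerBlockSketch(nn.Module):
    """Illustrative pre-norm block; the real TransformerBlock in model.py may differ."""
    def __init__(self, dim_embed: int, num_head: int, max_length: int):
        super().__init__()
        # max_length would normally size a causal attention mask; it is unused in this sketch.
        self.attention = nn.MultiheadAttention(dim_embed, num_head, batch_first=True)
        self.feed_forward = nn.Sequential(
            nn.Linear(dim_embed, 4 * dim_embed),
            nn.ReLU(),
            nn.Linear(4 * dim_embed, dim_embed),
        )
        self.norm1 = nn.LayerNorm(dim_embed)
        self.norm2 = nn.LayerNorm(dim_embed)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normed = self.norm1(x)
        attention_out, _ = self.attention(normed, normed, normed, need_weights=False)
        x = x + attention_out                      # residual connection around attention
        x = x + self.feed_forward(self.norm2(x))   # residual connection around the MLP
        return x
```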
```python
@@ -89,7 +89,8 @@ def __init__(self, dim_embed: int, num_heads: int, head_size: int, max_length: i
self.heads = nn.ModuleList([AttentionHead(dim_embed, head_size, max_length) for _ in range(num_heads)])
# Create a linear layer to project the concatenated output of all heads to the original dimension.
# In our case, the concatenated output is happen to be the same as the original dimension, so we can skip
```
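The projection mentioned in that comment concatenates the per-head outputs and maps them back to `dim_embed`; a hedged sketch, using plain linear layers as stand-ins for the repository's `AttentionHead`:

```python
import torch
import torch.nn as nn

class MultiHeadProjectionSketch(nn.Module):
    """Illustrative only; the real module uses AttentionHead instances from model.py."""
    def __init__(self, dim_embed: int, num_heads: int, head_size: int):
        super().__init__()
        # Stand-ins for AttentionHead; each maps dim_embed -> head_size.
        self.heads = nn.ModuleList([nn.Linear(dim_embed, head_size) for _ in range(num_heads)])
        # When num_heads * head_size == dim_embed, this projection maps back to the original dimension.
        self.projection = nn.Linear(num_heads * head_size, dim_embed)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.cat([head(x) for head in self.heads], dim=-1)  # concatenate along the feature dimension
        return self.projection(out)
```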
Comment on lines 231 to 234
```
Process a batch of token id lists.
Emplace 0 to the positions that exceed the actual length of each item, and mask these positions in the label by setting them to -100.
This is necessary to let the model know where to stop(first 0 in label) and ignore the rest padding tokens in the loss calculation.
```
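A minimal sketch of the padding and masking described in this docstring; the function name and the exact handling of the end-of-text position are assumptions, not the code in dataset.py:

```python
import torch

def pad_batch(token_id_lists, max_length):
    batch_size = len(token_id_lists)
    inputs = torch.zeros(batch_size, max_length, dtype=torch.long)         # pad inputs with 0
    labels = torch.full((batch_size, max_length), -100, dtype=torch.long)  # -100 masks padded label positions
    for i, ids in enumerate(token_id_lists):
        length = min(len(ids), max_length)
        inputs[i, :length] = torch.tensor(ids[:length])
        labels[i, :length - 1] = torch.tensor(ids[1:length])  # label = next token at each position
        labels[i, length - 1] = 0                             # keep one end-of-text 0 so the model learns to stop
    return inputs, labels
```

`F.cross_entropy` ignores target positions equal to -100 by default (its `ignore_index` default), which is why masking labels with -100 drops the padding tokens from the loss.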
Author
Addressed the current review feedback in 23ed735.
```python
Test the unreduced loss path used by DPO without relying on dataset loading.
"""
torch.manual_seed(2024)
# Keep the constructor device intentionally stale to verify forward uses the input tensor device.
```
```python
assert logits.shape == (8, 17)
assert loss.shape == (8,)
assert not caught
```
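For context on why the unreduced path matters, a hedged sketch of how a per-sequence loss can feed a DPO-style reward; the forward signature and `beta` value are assumptions, not the repository's API:

```python
import torch

def sequence_log_prob(model, token_ids, labels):
    # With reduce_loss=False the loss is the per-sequence negative log-likelihood,
    # so negating it gives the log-probability of each sequence in the batch.
    _, loss = model(token_ids, labels, reduce_loss=False)  # loss shape: (batch,)
    return -loss

def dpo_reward(policy_log_prob, reference_log_prob, beta=0.1):
    # DPO-style implicit reward: scaled difference between policy and reference log-probs.
    return beta * (policy_log_prob - reference_log_prob)
```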
Author
Followed up on the latest automated review in b47f7b4.
Pull request overview
Copilot reviewed 4 out of 4 changed files in this pull request and generated no new comments.
Comments suppressed due to low confidence (1)
tests/test_run.py:40
- Docstring typo: "overal" should be "overall".
"""
Test the overal pipeline runs without error.
"""
Summary
- Replace the deprecated `reduce` argument in `F.cross_entropy` with `reduction`
- Fix the `TransformerBlock` typo in the implementation and tutorial snippet
- Sync the docstrings in `dataset.py` with the actual sampling logic
- Add a regression test for the `reduce_loss=False` path used by DPO

Why
The code currently relies on `reduce_loss=False` when computing DPO rewards. Using the deprecated `reduce` argument still works, but it emits a PyTorch deprecation warning. This change keeps the behavior the same while making the code forward-compatible.

The tutorial chapter also needs to stay consistent with the implementation so readers do not copy an outdated class name.
Validation
- Ran the regression test covering `reduce_loss=False` successfully in the `v100` environment