mastmelon/GPT2-124M

Differences from the original Transformer paper
1. No encoder (decoder-only model)
2. Layer normalization moved to the input of each sub-block (pre-norm)
3. An additional layer normalization was added after the final self-attention block
4. The residual connection lets the input flow through each block unmodified
    - Gradients also flow from the top of the network directly back to the input along the residual path
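
The following is a minimal pure-Python sketch of the pre-norm block structure. It assumes each hidden state is a plain list of floats and that the attention and MLP sub-layers are arbitrary callables. `layer_norm` and `pre_ln_block` are hypothetical names, not this repo's code. Real attention also mixes information across token positions, which is omitted here. When both sub-layers return zeros, the block returns its input unchanged; this is the residual path from point 4:

```python
import math
from typing import Callable, List

Vector = List[float]

def layer_norm(x: Vector, eps: float = 1e-5) -> Vector:
    """Normalize to zero mean and unit variance (learnable scale/shift omitted)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def pre_ln_block(x: Vector,
                 attn: Callable[[Vector], Vector],
                 mlp: Callable[[Vector], Vector]) -> Vector:
    """GPT-2 style block: LayerNorm at the input of each sub-block, residual add around it."""
    # x = x + attn(ln_1(x)): the residual path carries x through unmodified
    x = [a + b for a, b in zip(x, attn(layer_norm(x)))]
    # x = x + mlp(ln_2(x))
    x = [a + b for a, b in zip(x, mlp(layer_norm(x)))]
    return x

zero = lambda v: [0.0] * len(v)
print(pre_ln_block([1.0, -2.0, 3.0], zero, zero))  # [1.0, -2.0, 3.0]
```

In the original (post-norm) Transformer, the normalization wraps the residual sum instead: `x = layer_norm(x + attn(x))`. As a result, the identity path itself passes through a LayerNorm.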

Using the approximate (tanh) version of GELU in the MLP
- The approximation used to be faster to compute than the exact GELU; this is no longer the case
- We use it anyway because we are replicating GPT-2 exactly
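
This short sketch compares two forms of GELU. The exact form is `x * Phi(x)`, where `Phi` is the standard normal CDF. The other is the tanh approximation used by GPT-2. In PyTorch these correspond to `nn.GELU()` and `nn.GELU(approximate='tanh')`. The two curves differ only slightly:

```python
import math

def gelu_exact(x: float) -> float:
    """Exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x: float) -> float:
    """tanh approximation of GELU, as used in GPT-2."""
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

for x in (-3.0, -1.0, 0.0, 0.5, 2.0):
    print(f"{x:5.1f}  exact={gelu_exact(x):.6f}  tanh={gelu_tanh(x):.6f}")
```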

GELU instead of ReLU
- Avoids the dead ReLU neuron problem. A pre-activation that falls in ReLU's flat negative tail gets a gradient of exactly 0, so that neuron cannot change or adapt. GELU keeps a small nonzero gradient there.
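
The dead-neuron argument comes down to the derivatives. This small sketch evaluates both gradients at negative inputs; it assumes the exact GELU, and the helper names are hypothetical. ReLU's gradient is exactly 0 there, while GELU's is small but nonzero, so a GELU unit can still be pushed back into its active range:

```python
import math

def relu_grad(x: float) -> float:
    """Derivative of ReLU: 0 for every negative input, so a neuron stuck there stops learning."""
    return 1.0 if x > 0 else 0.0

def gelu_grad(x: float) -> float:
    """Derivative of exact GELU x * Phi(x): Phi(x) + x * phi(x)."""
    cdf = 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))
    pdf = math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)
    return cdf + x * pdf

for x in (-3.0, -1.0, -0.1):
    print(f"x={x:5.1f}  relu'={relu_grad(x):.4f}  gelu'={gelu_grad(x):.4f}")
```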


Attention
- Multiplying the attention weights softmax(QK^T / sqrt(d_k)) by V gives, for each token, a weighted sum of the values of the tokens it found relevant
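
Below is a minimal single-head causal attention in pure Python. It assumes `q`, `k` and `v` are lists of per-token vectors, with no batching, no multiple heads and no learned projections. Each output row is the softmax-weighted sum of value vectors described above:

```python
import math
from typing import List

Matrix = List[List[float]]

def softmax(row: List[float]) -> List[float]:
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]

def causal_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """out[t] = sum over j <= t of softmax_j(q[t] . k[j] / sqrt(d)) * v[j]."""
    d = len(q[0])
    out = []
    for t in range(len(q)):
        # Scores only over positions 0..t: the causal mask hides future tokens.
        scores = [sum(a * b for a, b in zip(q[t], k[j])) / math.sqrt(d) for j in range(t + 1)]
        weights = softmax(scores)
        # Weighted sum of the value vectors of the attended tokens.
        out.append([sum(w * v[j][c] for j, w in enumerate(weights)) for c in range(len(v[0]))])
    return out

q = [[1.0, 0.0], [0.0, 0.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(causal_attention(q, k, v))
```

The first token can only attend to itself, so its output is exactly `v[0]`. The second query here scores both visible tokens equally and gets their plain average.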

torch.compile
- Doesn't run the model op by op in Python-interpreter (eager) style; it looks at the complete code and optimizes it as a whole
- Reads and writes data in batches to minimize round trips to GPU HBM (high-bandwidth memory), largely through kernel fusion
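
This toy sketch only illustrates why kernel fusion saves memory traffic; it is not how torch.compile works internally. It evaluates the tanh GELU on a Python list standing in for HBM and counts element reads and writes. The unfused version behaves like eager mode: it launches one "kernel" per operation and materializes every intermediate array. The fused version reads each input once and writes each output once. The functions and the traffic accounting are hypothetical:

```python
import math
from typing import List, Tuple

C = math.sqrt(2.0 / math.pi)

def gelu_unfused(x: List[float]) -> Tuple[List[float], int]:
    """Eager style: every op is its own 'kernel' that reads inputs and writes a full intermediate."""
    n = len(x)
    cube = [v ** 3 for v in x]                            # read x, write cube
    inner = [v + 0.044715 * c for v, c in zip(x, cube)]   # read x and cube, write inner
    t = [math.tanh(C * v) for v in inner]                 # read inner, write t
    out = [0.5 * v * (1.0 + tv) for v, tv in zip(x, t)]   # read x and t, write out
    traffic = 2 * n + 3 * n + 2 * n + 3 * n               # element reads + writes
    return out, traffic

def gelu_fused(x: List[float]) -> Tuple[List[float], int]:
    """Fused: one pass; intermediates stay in 'registers' and never touch memory."""
    out = [0.5 * v * (1.0 + math.tanh(C * (v + 0.044715 * v ** 3))) for v in x]
    traffic = 2 * len(x)                                  # read x once, write out once
    return out, traffic

xs = [i / 10 for i in range(-20, 21)]
(a, traffic_a), (b, traffic_b) = gelu_unfused(xs), gelu_fused(xs)
print(traffic_a, traffic_b)
```

In PyTorch itself this is a single call, `model = torch.compile(model)`, after which the compiler decides what to fuse.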

Flash attention
- A kernel fusion operation: Matmul | Mask | Softmax | Dropout | Matmul -> one fused kernel
- Mindful of the memory hierarchy, so it performs far fewer reads and writes to HBM
- The full T x T attention matrix is never written to HBM
- Uses the online softmax trick, rescaling partial results as it goes
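
Here is a minimal sketch of the online softmax trick for one query row. To keep it short, it assumes a single scalar value per position. FlashAttention applies the same idea on the GPU to tiles of queries, keys and vector values. The sketch consumes scores in blocks and keeps only three running quantities: a max, a normalizer and a weighted sum. Whenever the max grows, the earlier partial sums are rescaled by `exp(old_max - new_max)`. The result therefore matches the naive softmax without ever storing the full row of weights:

```python
import math
from typing import List

def naive_attention_row(scores: List[float], values: List[float]) -> float:
    """Reference: materialize all softmax weights, then take the weighted sum."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    return sum(wi * v for wi, v in zip(w, values)) / sum(w)

def online_attention_row(scores: List[float], values: List[float], block: int = 2) -> float:
    """Single streaming pass over blocks of scores with running max, normalizer and accumulator."""
    running_max = -math.inf
    denom = 0.0  # running sum of exp(score - running_max)
    acc = 0.0    # running sum of exp(score - running_max) * value
    for start in range(0, len(scores), block):
        s_blk = scores[start:start + block]
        v_blk = values[start:start + block]
        new_max = max(running_max, max(s_blk))
        # Rescale everything accumulated so far to the new max (exp(-inf) == 0.0 on the first block).
        scale = math.exp(running_max - new_max)
        denom *= scale
        acc *= scale
        for s, v in zip(s_blk, v_blk):
            e = math.exp(s - new_max)
            denom += e
            acc += e * v
        running_max = new_max
    return acc / denom

scores = [0.5, 2.0, -1.0, 3.0, 0.0]
values = [1.0, -2.0, 0.5, 4.0, 3.0]
print(naive_attention_row(scores, values), online_attention_row(scores, values, block=2))
```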

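Nice numbers
- Kernels work on fixed-size blocks: the full ("nice") blocks are computed first, then the leftovers are handled separately and less efficiently
- Prefer dimensions divisible by large powers of 2 so there are few or no leftovers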
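
This small sketch shows the full-tiles-plus-leftovers arithmetic, assuming a hypothetical tile size of 128. GPT-2's vocabulary size of 50257 leaves an 81-element remainder. Some GPT-2 reimplementations pad the vocabulary up to the next nice number, 50304, so that every tile is full. The extra token ids never occur in the data, so the model should simply learn to give them very low probability:

```python
def round_up(n: int, multiple: int) -> int:
    """Smallest multiple of `multiple` that is >= n."""
    return ((n + multiple - 1) // multiple) * multiple

def tiles(n: int, tile: int) -> tuple:
    """How a kernel with a fixed tile size would split n: (full tiles, leftover elements)."""
    return divmod(n, tile)

vocab_size = 50257                      # GPT-2 vocabulary size
print(tiles(vocab_size, 128))           # (392, 81): 392 full tiles plus an 81-element leftover
padded = round_up(vocab_size, 128)
print(padded, tiles(padded, 128))       # 50304 (393, 0): no leftover tile to handle
```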

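Global norm clipping
- Compute the L2 norm of all gradients together, as if they were one long vector; if it exceeds a threshold (e.g. 1.0), scale every gradient down by the same factor
- Prevents an unlucky batch with very large gradients from producing a huge update that destabilizes training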
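
The following is a minimal pure-Python sketch of global norm clipping. It treats each parameter's gradient as a flat list, and `clip_grad_global_norm` is a hypothetical name. Every gradient is multiplied by the same factor, so the update keeps its direction and only its overall length is capped. The PyTorch built-in is `torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)`, which likewise returns the total norm measured before clipping:

```python
import math
from typing import List

def clip_grad_global_norm(grads: List[List[float]], max_norm: float) -> float:
    """Scale all gradients in place so their combined L2 norm is at most max_norm.

    Returns the global norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        for tensor in grads:
            for i in range(len(tensor)):
                tensor[i] *= scale
    return total_norm

grads = [[3.0, 4.0], [0.0]]   # two "parameters"; global norm is sqrt(9 + 16) = 5
print(clip_grad_global_norm(grads, max_norm=1.0), grads)
```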

Learning rate
- Linear LR warmup at first
- Cosine decay until a minimum value is reached
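
Here is a sketch of the schedule as a plain function of the step number. The concrete values are hypothetical placeholders, not settings taken from this repo: `max_lr=6e-4`, `min_lr` at 10% of it, 10 warmup steps and 50 total steps:

```python
import math

def get_lr(step: int, max_lr: float = 6e-4, min_lr: float = 6e-5,
           warmup_steps: int = 10, max_steps: int = 50) -> float:
    """Linear warmup to max_lr, then cosine decay down to min_lr, then constant."""
    if step < warmup_steps:
        # Linear warmup; (step + 1) so the very first step does not use lr = 0.
        return max_lr * (step + 1) / warmup_steps
    if step > max_steps:
        return min_lr
    decay_ratio = (step - warmup_steps) / (max_steps - warmup_steps)  # goes 0 -> 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))             # goes 1 -> 0
    return min_lr + coeff * (max_lr - min_lr)

for s in (0, 5, 9, 10, 30, 50, 100):
    print(s, f"{get_lr(s):.2e}")
```

Before every optimizer step, the returned value would be written into each optimizer param group (`param_group['lr'] = lr`).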

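Weight decay and the fused setting in AdamW
- Weight decay pulls weights toward zero, keeping them from growing extremely large
- The fused setting runs the AdamW update with fused kernels, which is faster
- not implemented in code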
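
This minimal sketch performs one AdamW update for a single scalar weight, to show what "decoupled" weight decay means. The weight is shrunk by `lr * weight_decay` directly, separately from the Adam step driven by the gradient. The defaults `betas=(0.9, 0.95)` and `weight_decay=0.1` are assumptions borrowed from common GPT-style setups, not values from this repo:

```python
import math

def adamw_step(p, g, m, v, step, lr=6e-4, betas=(0.9, 0.95), eps=1e-8, weight_decay=0.1):
    """One AdamW update for a scalar parameter p with gradient g (step counts from 1).

    Returns the updated (p, m, v).
    """
    b1, b2 = betas
    # Decoupled weight decay: shrink the weight directly, independent of the gradient.
    p = p * (1.0 - lr * weight_decay)
    # Standard Adam moment estimates with bias correction.
    m = b1 * m + (1.0 - b1) * g
    v = b2 * v + (1.0 - b2) * g * g
    m_hat = m / (1.0 - b1 ** step)
    v_hat = v / (1.0 - b2 ** step)
    p = p - lr * m_hat / (math.sqrt(v_hat) + eps)
    return p, m, v

# With a zero gradient, only the decay acts: 1.0 -> 1.0 * (1 - 0.1 * 0.1).
print(adamw_step(1.0, 0.0, 0.0, 0.0, step=1, lr=0.1))
```

In PyTorch this corresponds to `torch.optim.AdamW(..., weight_decay=..., fused=True)` on CUDA parameters. A common convention, assumed here, is to decay only the 2-D weight matrices and embeddings. Biases and LayerNorm parameters are then excluded by passing two parameter groups.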

- Hyperparameters are correlated: changing one, such as the batch size, usually requires retuning others, such as the learning rate

Gradient accumulation
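- Lets us use a large batch size that does not fit in GPU memory at once (we can't simply use a smaller batch because hyperparameters are correlated)
- Run the forward and backward pass several times on micro-batches, accumulate the gradients, then do a single parameter update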
- not implemented in code
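
This toy check, in pure Python, shows that accumulating micro-batch gradients reproduces the full-batch gradient. It uses a one-parameter linear model with an MSE loss; the data and names are made up for illustration. The key detail is dividing each micro-batch contribution by the number of accumulation steps, because each micro-batch loss is already a mean:

```python
from typing import List

def grad_mse(w: float, xs: List[float], ys: List[float]) -> float:
    """d/dw of the mean squared error of the model y_hat = w * x over a batch."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.1, 3.9, 6.2, 7.8, 10.1, 12.2, 13.8, 16.1]
w = 0.5

full_batch_grad = grad_mse(w, xs, ys)

# Gradient accumulation: 4 micro-batches of 2 samples, gradients summed before one update.
grad_accum_steps = 4
micro = len(xs) // grad_accum_steps
accumulated = 0.0
for k in range(grad_accum_steps):
    mb_x = xs[k * micro:(k + 1) * micro]
    mb_y = ys[k * micro:(k + 1) * micro]
    # Each micro-batch loss is a mean over 2 samples, so divide by grad_accum_steps
    # to recover the mean over the full batch (forgetting this scales the gradient by 4).
    accumulated += grad_mse(w, mb_x, mb_y) / grad_accum_steps

lr = 0.01
w_new = w - lr * accumulated   # single parameter update
print(full_batch_grad, accumulated, w_new)
```

In PyTorch the same effect comes from calling `(loss / grad_accum_steps).backward()` on each micro-batch, since `.grad` accumulates by default. You then call `optimizer.step()` and `optimizer.zero_grad()` once.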

Distributed Data Parallel (DDP)
- not implemented in code
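
This pure-Python simulation covers the two things DDP relies on, assuming a simple token-stream loader; `rank_batches` and `all_reduce_mean` are hypothetical helpers. First, each process reads a disjoint slice of the data. Second, after the backward pass the gradients are averaged across processes, so every copy of the model applies the same update:

```python
from typing import Iterator, List, Tuple

def rank_batches(tokens: List[int], B: int, T: int,
                 rank: int, world_size: int) -> Iterator[Tuple[List[int], List[int]]]:
    """Yield (inputs, targets) chunks of B*T tokens for one process.

    Ranks start at different offsets and stride by B*T*world_size, so they never read the same inputs.
    """
    pos = B * T * rank
    while pos + B * T + 1 <= len(tokens):
        buf = tokens[pos:pos + B * T + 1]
        yield buf[:-1], buf[1:]          # targets are inputs shifted by one; viewed as (B, T) in practice
        pos += B * T * world_size

def all_reduce_mean(per_rank_grads: List[List[float]]) -> List[float]:
    """What DDP does after backward: every rank ends up with the average gradient."""
    n = len(per_rank_grads)
    return [sum(vals) / n for vals in zip(*per_rank_grads)]

tokens = list(range(33))
seen = [x for r in range(2) for x, _ in rank_batches(tokens, B=2, T=4, rank=r, world_size=2)]
print(sorted(t for chunk in seen for t in chunk) == list(range(32)))
print(all_reduce_mean([[1.0, 2.0], [3.0, 6.0]]))
```

With real DDP the averaging is done by an all-reduce during `loss.backward()`, and the processes are usually started with `torchrun`.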

Dataset
- FineWeb
- FineWeb-Edu (sample-10BT)

Next:
1. Hugging Face transformers library GPT-2 implementation - https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py

Parameter sharing between the token embedding and lm_head
- self.transformer.wte.weight = self.lm_head.weight # both now reference the same tensor; no data is copied
- https://arxiv.org/abs/1608.05859
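
A quick parameter count shows why the weight sharing matters for GPT-2 124M. The sketch assumes the standard GPT-2 small configuration (vocabulary 50257, context 1024, width 768, 12 layers) and an `lm_head` without bias. Tying removes a 50257 x 768 matrix, about 38.6M parameters or roughly 31% of the model:

```python
def gpt2_param_count(vocab=50257, n_ctx=1024, d=768, n_layer=12, tie_weights=True):
    """Count GPT-2 parameters (weights + biases) for the GPT-2 small layer layout."""
    wte = vocab * d                                  # token embedding
    wpe = n_ctx * d                                  # position embedding
    ln = 2 * d                                       # LayerNorm weight + bias
    attn = (d * 3 * d + 3 * d) + (d * d + d)         # c_attn (q, k, v) + c_proj
    mlp = (d * 4 * d + 4 * d) + (4 * d * d + d)      # c_fc + c_proj
    block = 2 * ln + attn + mlp                      # ln_1, ln_2, attention, MLP
    total = wte + wpe + n_layer * block + ln         # + final ln_f
    if not tie_weights:
        total += vocab * d                           # separate lm_head matrix (no bias)
    return total

print(gpt2_param_count())                    # 124439808
print(gpt2_param_count(tie_weights=False))   # 163037184
```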





About

GPT-2 124M PyTorch implementation
