mastmelon/GPT2-124M
Differences from the original Transformer paper
1. No encoder
2. Layer normalization moved to the input of each sub-block
3. Additional layer normalization was added after the final self-attention block
4. Input flows through each block unmodified via the residual connection
- Gradients flow from the top straight down to the input via the residual connections
Using the approximate (tanh) version of GELU in the MLP
- The approximation was originally used because it was faster than the exact version. This is no longer the case
- It is kept here because we are replicating GPT-2 exactly
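To make the difference concrete, here is a small standard-library sketch that compares the exact (erf-based) GELU with the tanh approximation. The function names `gelu_exact` and `gelu_tanh` are illustrative and not taken from the repository.

```python
import math

def gelu_exact(x: float) -> float:
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x: float) -> float:
    # Tanh approximation of GELU (the variant these notes refer to).
    c = math.sqrt(2.0 / math.pi)
    return 0.5 * x * (1.0 + math.tanh(c * (x + 0.044715 * x ** 3)))

# Largest absolute gap between the two on a grid over [-5, 5].
max_gap = max(
    abs(gelu_exact(i / 100) - gelu_tanh(i / 100)) for i in range(-500, 501)
)
print(f"max |exact - tanh| on [-5, 5]: {max_gap:.2e}")
```

The two curves stay very close over this range, which is why the choice matters mainly for exact reproduction rather than for model quality.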
GELU instead of ReLU
- Dead ReLU neuron problem: activations that fall on the flat (negative) tail of ReLU get gradient = 0, so no change or adaptation is possible
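The zero-gradient point can be checked numerically. The sketch below uses the exact GELU for simplicity (an assumption made for brevity; the tanh variant behaves similarly) and compares the derivative of each activation at a negative input.

```python
import math

def relu_grad(x: float) -> float:
    # Derivative of ReLU: 1 for positive inputs, 0 otherwise.
    return 1.0 if x > 0 else 0.0

def gelu_grad(x: float) -> float:
    # Derivative of exact GELU: Phi(x) + x * phi(x),
    # with Phi the standard normal CDF and phi its density.
    cdf = 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))
    pdf = math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)
    return cdf + x * pdf

x = -1.0
print("ReLU grad at -1:", relu_grad(x))  # exactly zero: the neuron cannot adapt
print("GELU grad at -1:", gelu_grad(x))  # small but nonzero
```

Because the GELU gradient is nonzero for moderately negative inputs, a neuron that lands there can still receive updates.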
Attention
- Multiplying the attention weights by V gives a weighted sum of the values of the tokens that QK^T (after softmax) marked as interesting
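The weighted-sum view of attention can be sketched for a single head with plain Python lists. This is an illustration under simplifying assumptions (one head, no batching, no learned projections), not the repository's implementation; `causal_attention` is a hypothetical name.

```python
import math

def softmax(xs):
    # Numerically stable softmax over a list of floats.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def causal_attention(q, k, v):
    # q, k, v: lists of T vectors. Position t attends to positions 0..t only.
    d = len(q[0])
    out = []
    for t in range(len(q)):
        # Scaled dot-product scores of query t against keys 0..t (causal mask).
        scores = [
            sum(qi * ki for qi, ki in zip(q[t], k[s])) / math.sqrt(d)
            for s in range(t + 1)
        ]
        weights = softmax(scores)
        # Output is the weighted sum of the value vectors.
        out.append([
            sum(w * v[s][j] for s, w in enumerate(weights))
            for j in range(len(v[0]))
        ])
    return out

# With all-zero queries and keys every visible token gets equal weight,
# so each output is the running mean of the values seen so far.
q = k = [[0.0], [0.0]]
v = [[1.0], [3.0]]
result = causal_attention(q, k, v)
print(result)
```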
torch.compile
- Does not run the model op by op the way the Python interpreter does. It looks at the complete code and optimizes it as a whole
- Reads and writes in batches: reduces round trips to HBM, largely through kernel fusion
flash attention
- kernel fusion operation
- mindful of memory hierarchy
- Matmul | Mask | Softmax | Dropout | Matmul -> Fused Kernel
- fewer reads and writes to HBM
- the full attention matrix is never materialized in HBM
- uses the online softmax trick: running results are rescaled as each new block is processed
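The online softmax idea can be sketched for a single query row. The code below only illustrates the rescaling trick in plain Python with scalar values; it is not a fused kernel, and the function names and `block_size` parameter are hypothetical.

```python
import math

def attention_row_online(scores, values, block_size=2):
    # Processes scores block by block while keeping only three running numbers,
    # so the full softmax row never has to be stored.
    running_max = float("-inf")
    running_sum = 0.0  # sum of exp(score - running_max) seen so far
    acc = 0.0          # sum of exp(score - running_max) * value seen so far
    for start in range(0, len(scores), block_size):
        s_blk = scores[start:start + block_size]
        v_blk = values[start:start + block_size]
        new_max = max(running_max, max(s_blk))
        # Rescale earlier partial results to the new reference maximum.
        scale = math.exp(running_max - new_max)
        probs = [math.exp(s - new_max) for s in s_blk]
        running_sum = running_sum * scale + sum(probs)
        acc = acc * scale + sum(p * v for p, v in zip(probs, v_blk))
        running_max = new_max
    return acc / running_sum

def attention_row_reference(scores, values):
    # Ordinary softmax followed by a weighted sum, for comparison.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    return sum(e * v for e, v in zip(exps, values)) / sum(exps)

scores = [0.5, 2.0, -1.0, 3.0, 0.0]
values = [1.0, 2.0, 3.0, 4.0, 5.0]
online = attention_row_online(scores, values)
reference = attention_row_reference(scores, values)
print(online, reference)
```

Both functions should agree up to floating-point rounding, which is the property that lets a blockwise kernel avoid writing the whole attention row to memory.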
nice numbers
- Kernels work on blocks (tiles): sizes that divide evenly into blocks are computed in one efficient pass, while other sizes need an extra, slower pass for the leftovers
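One way to get a "nice" size is to round a dimension up to the next multiple of a block size. The helper below is a hypothetical sketch; the choice of 128 as the multiple is an assumption for illustration, and 50257 is the GPT-2 vocabulary size.

```python
def round_up(n: int, multiple: int) -> int:
    # Smallest multiple of `multiple` that is >= n.
    return ((n + multiple - 1) // multiple) * multiple

vocab_size = 50257                     # GPT-2 vocabulary size
padded = round_up(vocab_size, 128)     # 128 is an illustrative block multiple
print(vocab_size, "->", padded)
```

The padded rows are never produced by the tokenizer, so they only cost a little extra memory while keeping the matrix dimensions block-aligned.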
global norm clipping
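- Compute the L2 norm over all gradients together and, if it exceeds a threshold, scale every gradient down by the same factor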
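Global norm clipping can be sketched without any framework: treat all gradients as one long vector, measure its L2 norm, and shrink everything uniformly if the norm is too large. The function below is a hypothetical illustration on nested lists; in PyTorch the same operation is provided by `torch.nn.utils.clip_grad_norm_`.

```python
import math

def clip_global_norm(grads, max_norm):
    # grads: list of gradient "tensors", each a flat list of floats.
    total_norm = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    # Scale is 1.0 when the norm is already within bounds.
    scale = min(1.0, max_norm / (total_norm + 1e-6))
    clipped = [[g * scale for g in tensor] for tensor in grads]
    return clipped, total_norm

grads = [[3.0], [4.0]]  # global norm is 5.0
clipped, norm_before = clip_global_norm(grads, max_norm=1.0)
norm_after = math.sqrt(sum(g * g for tensor in clipped for g in tensor))
print(norm_before, norm_after)
```

Because every gradient is multiplied by the same factor, the update direction is preserved and only its length changes.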
Learning rate
- Linear LR warmup at first
- Cosine decay till a min value is reached
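The schedule described above can be written as a small function of the step number. This is a sketch; the parameter names (`max_lr`, `min_lr`, `warmup_steps`, `max_steps`) and the values in the example are assumptions, not settings taken from the repository.

```python
import math

def get_lr(step, max_lr, min_lr, warmup_steps, max_steps):
    # 1) Linear warmup from a small value up to max_lr.
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    # 2) After the decay window, hold the minimum value.
    if step >= max_steps:
        return min_lr
    # 3) Cosine decay from max_lr down to min_lr in between.
    progress = (step - warmup_steps) / (max_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * progress))  # goes from 1 to 0
    return min_lr + coeff * (max_lr - min_lr)

# Illustrative values only.
for s in (0, 9, 10, 55, 100):
    print(s, get_lr(s, max_lr=1.0, min_lr=0.1, warmup_steps=10, max_steps=100))
```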
Weight decay and the fused setting in AdamW
- Weight decay pulls weights toward zero, which avoids extremely large values
- not implemented in code
- Hyperparameters are correlated
Gradient accumulation
- Simulates a large batch size (the batch can't simply be made smaller because the hyperparameters are correlated)
- Run the forward and backward pass multiple times, accumulate the gradients, and then do a single parameter update
- not implemented in code
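Since the notes say this is not implemented in the code, the following is only a framework-free sketch of the idea, using a one-parameter model `y = w * x` with mean squared error. All names and values are hypothetical. The key detail is dividing each micro-batch gradient by the number of accumulation steps so that the sum matches the gradient of the full batch.

```python
def grad_mse(w, xs, ys):
    # Gradient of mean((w*x - y)^2) with respect to w.
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
w, lr = 0.5, 0.01

# Gradient computed on the full batch in one go.
full_grad = grad_mse(w, xs, ys)

# Same gradient accumulated over micro-batches of size 2.
micro = 2
accum_steps = len(xs) // micro
accum_grad = 0.0
for i in range(accum_steps):
    xb = xs[i * micro:(i + 1) * micro]
    yb = ys[i * micro:(i + 1) * micro]
    # Scale by 1/accum_steps so the accumulated value is a mean, not a sum.
    accum_grad += grad_mse(w, xb, yb) / accum_steps

# Single parameter update after all micro-batches.
w_new = w - lr * accum_grad
print(full_grad, accum_grad, w_new)
```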
DistributedDataParallel (DDP)
- not implemented in code
Dataset
- FineWeb
- FineWeb-Edu, sample-10BT
Next:
1. huggingface/transformers library GPT-2 implementation - https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
Parameter sharing between the token embedding and lm_head
- self.transformer.wte.weight = self.lm_head.weight # both names now refer to the same tensor (shared by reference, not copied)
- https://arxiv.org/abs/1608.05859