Details: this project involves building and training Transformer models from scratch, with a focus on code quality and performance optimization. The team will primarily use PyTorch to implement the architecture (JAX/Flax is an acceptable alternative, but expect it to be harder), with jaxtyping for tensor shape checking. The goal is to optimize for iteration time and scalability.
Team Size: 3
Progression
- implement the core Transformer architecture in PyTorch or JAX (use tiktoken as the tokenizer for now); a minimal attention sketch with jaxtyping shape checks follows this list.
- train a baseline model on a standard dataset to verify correctness.
- optimize the training loop for throughput using torch.compile and mixed precision training (see the training-loop sketch below).
- implement Distributed Data Parallel (or the JAX equivalent) to enable multi-GPU training (see the DDP sketch below).
- (stretch) scale up the implementation (we can try ~80M parameters with 4x V100s for a couple of days).
- (stretch) implement Rotary Position Embeddings to improve relative position handling (see the RoPE sketch below).
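
As a starting point for the first milestone, here is a minimal sketch (not the project's final code) of a single-head causal self-attention block with jaxtyping shape annotations, plus tiktoken encoding. The hyperparameter names and sizes (d_model, n_ctx) are illustrative assumptions.

```python
import math

import tiktoken
import torch
import torch.nn as nn
import torch.nn.functional as F
from jaxtyping import Float, Int
from torch import Tensor


class CausalSelfAttention(nn.Module):
    def __init__(self, d_model: int, n_ctx: int):
        super().__init__()
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)
        # Lower-triangular mask so each position attends only to the past.
        mask = torch.ones(n_ctx, n_ctx, dtype=torch.bool).tril()
        self.register_buffer("mask", mask)

    # jaxtyping annotations document shapes; pair with @jaxtyped + beartype
    # if runtime checking is wanted.
    def forward(
        self, x: Float[Tensor, "batch seq d_model"]
    ) -> Float[Tensor, "batch seq d_model"]:
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
        seq = x.size(1)
        scores = scores.masked_fill(~self.mask[:seq, :seq], float("-inf"))
        return self.proj(F.softmax(scores, dim=-1) @ v)


# tiktoken gives us GPT-2's BPE vocabulary without training a tokenizer.
enc = tiktoken.get_encoding("gpt2")
ids: Int[Tensor, "seq"] = torch.tensor(enc.encode("hello transformer"))

attn = CausalSelfAttention(d_model=64, n_ctx=128)
x = torch.randn(1, ids.numel(), 64)  # pretend these tokens were embedded
print(attn(x).shape)                 # torch.Size([1, seq, 64])
```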
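For the throughput milestone, this is one way torch.compile plus fp16 autocast with gradient scaling could look. The toy model, optimizer settings, and random batches are stand-ins for the real training setup; fp16 (rather than bf16) is assumed because V100s lack bf16 support.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy stand-ins so the sketch runs end to end; the real project would plug in
# the Transformer from the previous step and a proper data loader.
model = nn.Sequential(nn.Embedding(50257, 128), nn.Linear(128, 50257)).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
loader = [(torch.randint(0, 50257, (8, 64)), torch.randint(0, 50257, (8, 64)))
          for _ in range(10)]

model = torch.compile(model)          # fuse ops and cut Python overhead
scaler = torch.cuda.amp.GradScaler()  # keeps fp16 gradients from underflowing

for inputs, targets in loader:
    inputs, targets = inputs.cuda(), targets.cuda()
    optimizer.zero_grad(set_to_none=True)
    # Autocast runs matmuls in fp16 and keeps loss reductions in fp32.
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        logits = model(inputs)
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                               targets.reshape(-1))
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```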
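For the multi-GPU milestone, here is one possible DDP setup, assuming a launch via `torchrun --nproc_per_node=4 train.py` so that RANK/LOCAL_RANK/WORLD_SIZE are set in the environment. The linear model and random dataset are placeholders for the real Transformer and data pipeline.

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = nn.Linear(128, 128).cuda(local_rank)  # stand-in for the Transformer
model = DDP(model, device_ids=[local_rank])   # gradients all-reduced each step

dataset = TensorDataset(torch.randn(1024, 128))
# DistributedSampler gives each rank a disjoint shard of the data.
sampler = DistributedSampler(dataset)
loader = DataLoader(dataset, batch_size=32, sampler=sampler)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
for epoch in range(2):
    sampler.set_epoch(epoch)  # reshuffle shards each epoch
    for (batch,) in loader:
        optimizer.zero_grad(set_to_none=True)
        loss = model(batch.cuda(local_rank)).pow(2).mean()  # dummy loss
        loss.backward()  # DDP overlaps the gradient all-reduce with backward
        optimizer.step()

dist.destroy_process_group()
```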
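For the RoPE stretch goal, a sketch of rotary embeddings applied to query/key tensors. The function name and tensor layout are our own assumptions; the base of 10000 follows the RoFormer paper.

```python
import torch
from jaxtyping import Float
from torch import Tensor


def rope(x: Float[Tensor, "batch head seq d"], base: float = 10000.0
         ) -> Float[Tensor, "batch head seq d"]:
    d, seq = x.size(-1), x.size(-2)
    # One rotation frequency per pair of channels: base ** (-2i / d).
    freqs = base ** (-torch.arange(0, d, 2, device=x.device) / d)
    angles = torch.arange(seq, device=x.device)[:, None] * freqs[None, :]
    cos, sin = angles.cos(), angles.sin()  # each (seq, d/2), broadcast below
    x1, x2 = x[..., 0::2], x[..., 1::2]
    # Rotate each (x1, x2) channel pair by its position-dependent angle.
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out


# Applied to queries and keys (not values) before computing attention scores.
q = torch.randn(2, 4, 16, 64)
k = torch.randn(2, 4, 16, 64)
q, k = rope(q), rope(k)
```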
Technologies
- Languages: Python
- Libraries & Frameworks: PyTorch, jaxtyping, tiktoken
References
- Primary Guide: Let's build GPT: from scratch, in code, spelled out
- Scaling up: Let's reproduce GPT-2
- Another cool guide: GPT in 60 Lines of NumPy