LLM Training: Transformer

Details: this project involves building and training Transformer models from scratch, with a focus on code quality and performance optimization. The team will primarily use PyTorch to implement the architecture (JAX/Flax is also an option, but it will be harder), with jaxtyping for tensor shape checking. The goal is to optimize for iteration time and scalability.
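
As a sketch of how jaxtyping shape checking might look here, the function below annotates attention inputs and validates them at call time. The dimension names and the use of beartype as the runtime checker are illustrative assumptions, not choices made in this repository.

```python
# Sketch of jaxtyping-style shape annotations with PyTorch.
# Dimension names (batch, heads, seq, d_head) are illustrative assumptions.
import torch
from torch import Tensor
from jaxtyping import Float, jaxtyped
from beartype import beartype

@jaxtyped(typechecker=beartype)
def scaled_dot_product_attention(
    q: Float[Tensor, "batch heads seq d_head"],
    k: Float[Tensor, "batch heads seq d_head"],
    v: Float[Tensor, "batch heads seq d_head"],
) -> Float[Tensor, "batch heads seq d_head"]:
    # Shapes are checked on every call; a mismatched tensor raises immediately.
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    return torch.softmax(scores, dim=-1) @ v
```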

Team Size: 3

Progression

  1. Implement the core Transformer architecture in PyTorch or JAX, using tiktoken as the tokenizer for now (see the first sketch after this list).
  2. Train a baseline model on a standard dataset to verify correctness.
  3. Optimize the training loop for throughput using torch.compile and mixed-precision training (second sketch below).
  4. Implement Distributed Data Parallel (or the JAX equivalent) to enable multi-GPU training (third sketch below).
  5. (stretch) Scale up the implementation (we can try ~80M parameters with 4x V100s for a couple of days).
  6. (stretch) Implement Rotary Position Embeddings to improve relative position handling (final sketch below).
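
For step 1, here is a minimal sketch of a from-scratch, pre-norm decoder block, with tiktoken handling tokenization. The hyperparameters (d_model=256, n_heads=4), the pre-norm layout, and the GPT-2 vocabulary are assumptions for illustration, not decisions this project has made.

```python
# Sketch of step 1: a minimal pre-norm decoder block written from scratch.
# Hyperparameters (d_model=256, n_heads=4) are illustrative assumptions.
import torch
import torch.nn as nn
import torch.nn.functional as F
import tiktoken

class CausalSelfAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.n_heads, self.d_head = n_heads, d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        b, s, d = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # (b, s, d) -> (b, heads, s, d_head)
        q, k, v = (t.view(b, s, self.n_heads, self.d_head).transpose(1, 2)
                   for t in (q, k, v))
        scores = q @ k.transpose(-2, -1) / self.d_head ** 0.5
        mask = torch.triu(torch.ones(s, s, dtype=torch.bool, device=x.device), 1)
        scores = scores.masked_fill(mask, float("-inf"))  # causal: no lookahead
        out = F.softmax(scores, dim=-1) @ v
        return self.proj(out.transpose(1, 2).reshape(b, s, d))

class DecoderBlock(nn.Module):
    def __init__(self, d_model: int = 256, n_heads: int = 4):
        super().__init__()
        self.ln1, self.ln2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_heads)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        return x + self.mlp(self.ln2(x))

# tiktoken handles tokenization (GPT-2 vocabulary as a placeholder choice).
enc = tiktoken.get_encoding("gpt2")
ids = torch.tensor([enc.encode("hello transformer")])
h = nn.Embedding(enc.n_vocab, 256)(ids)
print(DecoderBlock()(h).shape)  # (1, seq_len, 256)
```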
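
For step 3, a sketch of the two throughput levers named above: torch.compile and mixed precision. The toy linear model and synthetic data stand in for the real Transformer and dataset; fp16 with a GradScaler is chosen here because the V100s mentioned in step 5 lack bf16 support.

```python
# Sketch of step 3: torch.compile + mixed-precision training.
# The toy model and random data are placeholders for the real setup.
import torch
import torch.nn as nn

model = nn.Linear(256, 256).cuda()
model = torch.compile(model)                 # fuse ops, cut Python overhead
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scaler = torch.cuda.amp.GradScaler()         # fp16 gradients need loss scaling

for step in range(10):
    x = torch.randn(32, 256, device="cuda")
    optimizer.zero_grad(set_to_none=True)
    # fp16 autocast (V100s lack bf16, so fp16 + GradScaler is the fit here).
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = model(x).pow(2).mean()
    scaler.scale(loss).backward()            # scale up to avoid fp16 underflow
    scaler.step(optimizer)                   # unscales; skips step on inf/nan
    scaler.update()
```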
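
For step 4, a sketch of the Distributed Data Parallel setup, assuming a launch via `torchrun --nproc_per_node=4 train.py`; the toy model is again a placeholder, and a real run would also shard the dataset across ranks with a DistributedSampler.

```python
# Sketch of step 4: DDP with one process per GPU, launched by torchrun.
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    dist.init_process_group(backend="nccl")      # one process per GPU
    local_rank = int(os.environ["LOCAL_RANK"])   # set by torchrun
    torch.cuda.set_device(local_rank)

    model = nn.Linear(256, 256).cuda()
    model = DDP(model, device_ids=[local_rank])  # syncs grads across ranks
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

    for step in range(10):
        x = torch.randn(32, 256, device="cuda")
        optimizer.zero_grad(set_to_none=True)
        loss = model(x).pow(2).mean()
        loss.backward()                          # gradient all-reduce fires here
        optimizer.step()

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```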
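
Finally, for the stretch goal in step 6, a sketch of Rotary Position Embeddings: each even/odd pair of query/key features is rotated by a position-dependent angle, which makes attention scores depend on relative positions. The interleaved pairing and base=10000 follow the original RoPE paper's convention; this is one common formulation, not necessarily the one the team will adopt.

```python
# Sketch of step 6: Rotary Position Embeddings applied to q/k tensors.
import torch

def apply_rope(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    # x: (batch, heads, seq, d_head) with even d_head
    seq, d = x.shape[-2], x.shape[-1]
    # theta_i = base^(-2i/d), one frequency per feature pair
    inv_freq = base ** (-torch.arange(0, d, 2, device=x.device) / d)
    angles = torch.arange(seq, device=x.device)[:, None] * inv_freq  # (seq, d/2)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]          # even / odd feature pairs
    # 2D rotation of each pair by its position-dependent angle
    return torch.stack(
        (x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1
    ).flatten(-2)                                # re-interleave the pairs

q = torch.randn(1, 4, 8, 64)
q_rot = apply_rope(q)                            # same shape, position-encoded
```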

Technologies
