nickkipshidze/mini-bert

Mini BERT

This project was inspired by Stanford's CS224N course, which motivated me to build my own take on its "default final project". That project is meant to be completed over several weeks, but this one took only 3 days of work.

Most of that speed came from skipping the architecture implementation. I used PyTorch's built-in transformer instead, since I had already spent plenty of time building a transformer from scratch in previous projects (e.g. nickkipshidze/mini-generative-transformer).

Here are the parts of this project that I implemented from scratch:

  • BERT tokenizer (my own implementation, with near-perfect decoding)
  • End-to-end pre-training pipeline
  • Masked language modeling
  • End-to-end fine-tuning pipeline
  • Inference pipeline
  • Demos and examples (for you to check out my model)
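The tokenizer in this repo is my own implementation, and its internals aren't described in this README. As a point of reference, the original BERT tokenizer uses WordPiece: each word is split greedily into the longest vocabulary entry that matches from the left, and continuation pieces carry a "##" prefix. The sketch below shows that algorithm on a tiny hypothetical vocabulary. `wordpiece_tokenize`, `decode` and `vocab` are illustrative names, not this project's API, and the lowercasing and punctuation splitting that a full BERT tokenizer performs first are omitted.

```python
def wordpiece_tokenize(word, vocab, unk_token="[UNK]", max_chars=100):
    """Split one whitespace-delimited word into WordPiece sub-tokens.

    Greedy longest-match-first: continuation pieces get a "##" prefix.
    (Illustrative sketch, not the project's tokenizer.)
    """
    if len(word) > max_chars:
        return [unk_token]
    tokens = []
    start = 0
    while start < len(word):
        end = len(word)
        piece = None
        # Shrink the candidate from the right until it is in the vocabulary.
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            # No prefix of the remainder is known, so the whole word maps to [UNK].
            return [unk_token]
        tokens.append(piece)
        start = end
    return tokens


def decode(tokens):
    """Rebuild text by gluing '##' pieces onto the preceding token."""
    text = ""
    for tok in tokens:
        if tok.startswith("##"):
            text += tok[2:]
        else:
            text += (" " if text else "") + tok
    return text


# Hypothetical toy vocabulary for demonstration only.
vocab = {"the", "quick", "brown", "fox", "jump", "##s", "##ed", "over", "a", "lazy", "dog"}

tokens = []
for word in "the quick brown fox jumps over a lazy dog".split():
    tokens.extend(wordpiece_tokenize(word, vocab))
print(tokens)
print(decode(tokens))
```

Gluing "##" pieces back is lossless in this toy case; what makes decoding only "near-perfect" in practice is information the pre-tokenization step throws away, such as casing or whether a space preceded a punctuation mark.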

Training my model

Pre-training took 20 hours on an NVIDIA GeForce RTX 4060 Ti, the same GPU I used for fine-tuning. Everything ran fully offline on my local machine.

[Figure: pre-training loss curve on Wikipedia]

The dataset I used for pre-training is a sample of Wikipedia from bwandowando/wikipedia-index-and-plaintext-20230801.

Here's a silly sample I find amusing:

  • Original sample: The quick brown fox jumps over a lazy dog.
  • Masked sample: The quick brown fox [MASK] over a lazy dog.
  • Model prediction: The quick brown fox flies over a lazy dog.

The [MASK] token should have been replaced with "jumps", but the model decided to get creative.
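The example above hides a single word. This README doesn't state the masking rate or replacement scheme used during pre-training; as a reference point, the original BERT recipe selects about 15% of tokens for prediction and, of those, replaces 80% with [MASK], 10% with a random token and leaves 10% unchanged. Below is a minimal pure-Python sketch of that recipe, assuming integer token ids. `mask_tokens` and the ids in the usage lines are hypothetical, and the `-100` label value follows the default `ignore_index` of PyTorch's cross-entropy loss.

```python
import random

MASK_PROB = 0.15  # fraction of positions selected for prediction (original BERT value)


def mask_tokens(token_ids, mask_id, vocab_size, special_ids, rng):
    """Apply BERT-style 80/10/10 masking to one sequence of token ids.

    Returns (inputs, labels). labels holds the original id at selected
    positions and -100 elsewhere, so unselected positions are ignored
    by a loss that uses ignore_index=-100.
    """
    inputs = list(token_ids)
    labels = [-100] * len(token_ids)
    for i, tok in enumerate(token_ids):
        # Never mask special tokens such as [CLS] or [SEP].
        if tok in special_ids or rng.random() >= MASK_PROB:
            continue
        labels[i] = tok  # the model must recover this token
        roll = rng.random()
        if roll < 0.8:
            inputs[i] = mask_id  # 80%: replace with [MASK]
        elif roll < 0.9:
            inputs[i] = rng.randrange(vocab_size)  # 10%: random vocabulary id
        # remaining 10%: keep the original token unchanged
    return inputs, labels


# Hypothetical ids: 1 and 2 stand in for [CLS]/[SEP], 3 for [MASK].
rng = random.Random(0)
ids = [1, 17, 42, 8, 56, 23, 2]
inputs, labels = mask_tokens(ids, mask_id=3, vocab_size=100, special_ids={1, 2}, rng=rng)
print(inputs)
print(labels)
```

Leaving some selected tokens unchanged or randomized keeps the model from learning that it only needs to produce predictions at [MASK] positions, a token that never appears during fine-tuning.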

To evaluate the pre-trained model, I fine-tuned it on a sentiment analysis task using the jp797498e/twitter-entity-sentiment-analysis dataset. Here is the scikit-learn classification report summarizing the results after 88 minutes of fine-tuning:

              precision    recall  f1-score   support

           0       0.99      0.99      0.99       266
           1       0.97      0.98      0.98       285
           2       0.98      0.98      0.98       277
           3       0.99      0.98      0.99       172

    accuracy                           0.98      1000
   macro avg       0.98      0.98      0.98      1000
weighted avg       0.98      0.98      0.98      1000

Compared to the 1,200 minutes (20 hours) of pre-training, 88 minutes of fine-tuning (roughly 7% of that time) is computationally modest, yet it reaches a 0.98 macro-averaged F1-score on this four-class sentiment task, evaluated on 1,000 samples.
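The report above comes from scikit-learn's `classification_report`. To make its summary rows concrete, here is a from-scratch sketch of the same quantities: per-class precision, recall and F1, plus accuracy and the macro (plain mean over classes) and weighted (support-weighted mean) F1. `summarize` is a hypothetical helper for illustration, not part of this repo or of scikit-learn, and returning 0.0 on a zero denominator is an assumption that roughly mirrors scikit-learn's default behavior.

```python
from collections import Counter


def summarize(y_true, y_pred):
    """Per-class precision/recall/F1/support plus accuracy, macro F1 and weighted F1."""
    labels = sorted(set(y_true) | set(y_pred))
    tp, fp, fn = Counter(), Counter(), Counter()
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[t] += 1
        else:
            fp[p] += 1  # predicted p, but it was wrong
            fn[t] += 1  # true class t was missed
    support = Counter(y_true)

    per_class = {}
    for c in labels:
        precision = tp[c] / (tp[c] + fp[c]) if tp[c] + fp[c] else 0.0
        recall = tp[c] / (tp[c] + fn[c]) if tp[c] + fn[c] else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        per_class[c] = (precision, recall, f1, support[c])

    n = len(y_true)
    accuracy = sum(tp.values()) / n
    macro_f1 = sum(v[2] for v in per_class.values()) / len(labels)
    weighted_f1 = sum(v[2] * v[3] for v in per_class.values()) / n
    return per_class, accuracy, macro_f1, weighted_f1


# Toy four-class example (not the project's data).
y_true = [0, 0, 1, 1, 2, 2, 3, 3]
y_pred = [0, 0, 1, 2, 2, 2, 3, 0]
per_class, accuracy, macro_f1, weighted_f1 = summarize(y_true, y_pred)
for label, (p, r, f1, s) in per_class.items():
    print(f"{label:>5} {p:9.2f} {r:9.2f} {f1:9.2f} {s:9d}")
print(f"accuracy {accuracy:.2f}  macro F1 {macro_f1:.2f}  weighted F1 {weighted_f1:.2f}")
```

Because every per-class F1 score in the report above lies between 0.98 and 0.99, any average of them, macro or support-weighted, must fall in that same range.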

About

From-scratch implementation of masked language modeling (MLM) and a full BERT pre-training and fine-tuning pipeline (as well as the BERT tokenizer), for educational purposes.
