Run Whisper training with Google Cloud buckets #70
Conversation
jqug left a comment:
Thanks for this, looks good.
Just one thing, let's take out the gcloud auth for now and maybe mention in a comment in the file that this may be necessary.
whisper_training_setup.sh
Outdated
sudo apt-get install -y google-cloud-cli

gcloud init
gcloud auth application-default login
If we load from a public GCS bucket, is this step necessary?
We could include in the comments for the file that these gcloud lines should be run in order to access nonpublic buckets (but should only be done on trusted machines).
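One possible shape for the setup file following this suggestion (a sketch only; the exact comment wording and script layout are assumptions):

```shell
#!/bin/bash
# Install the Google Cloud CLI (needed to read from GCS buckets).
sudo apt-get install -y google-cloud-cli

# The lines below are only needed to access non-public buckets.
# Uncomment and run them manually, and only on trusted machines:
# gcloud init
# gcloud auth application-default login
```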
We should consider merging this notebook into the dedicated sunbird-speech repo:
…ngual_eval_fn processing; fix label 448 limit; launch full training
gradient_accumulation_steps: 4
learning_rate: 1.0e-05
warmup_steps: 500
max_steps: 7500
Maybe you already caught this, but I think this argument is why the training time looked like it was only going to be ~13 hours: training would stop prematurely after this number of steps. (With our previous dataset this was about 5 epochs, after which the model converged pretty well.)
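To make the interaction between `max_steps` and epochs concrete, here is a small hypothetical helper (the function name and the example numbers are illustrative, not taken from the pipeline) computing how many optimizer steps a given number of epochs actually needs:

```python
import math

def optimizer_steps(num_examples, per_device_batch_size, grad_accum_steps,
                    num_epochs, num_devices=1):
    """Optimizer steps needed for num_epochs full passes over the data."""
    # Each optimizer step consumes one effective batch of examples.
    effective_batch = per_device_batch_size * grad_accum_steps * num_devices
    steps_per_epoch = math.ceil(num_examples / effective_batch)
    return steps_per_epoch * num_epochs

# For example, with 240,000 examples, per-device batch size 8 and
# grad_accum_steps 4 (as in the config above), 5 epochs need
# 37,500 optimizer steps, so max_steps=7500 would stop training early.
```

If the required step count exceeds `max_steps`, the trainer stops at `max_steps` regardless of how many epochs were intended.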
… fix preprocess error
This PR adds support for Google Cloud buckets to the Whisper training pipeline, and makes several other changes:
- Load the dataset from a `gcs://path` with `datasets.load_dataset` and cast the audio column to `datasets.Audio` format.
- Use `salt.datasets` from the current repo instead of https://github.com/jqug/salt.git
- Set `gradient_checkpointing=False`.
- Set `torch_dtype=torch.float32` when loading the model weights.
- Update `model.generation_config` based on requirements from the new version.

Overfit experiment
An overfit experiment with just 100 examples was done to verify the changes:
MLflow run 1 with evaluation metrics: https://mlflow-sunbird-ce0ecfc14244.herokuapp.com/#/experiments/0/runs/2d488acdc39146e9af9da07c00128d49/model-metrics
MLflow run 2 with GPU utilization: https://mlflow.sunbird.ai/#/experiments/0/runs/811bbdf051f44597bd90c3376cfc9309/system-metrics
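As a sketch of the GCS loading described in the change list above (the helper name, bucket layout, and `jsonl` split files are assumptions for illustration, not the PR's actual code):

```python
def gcs_data_files(bucket_prefix, splits=("train", "dev", "test")):
    """Map split names to gcs:// file paths under a bucket prefix."""
    prefix = bucket_prefix.rstrip("/")
    # Accept either a bare bucket path or an already-qualified URI.
    if not prefix.startswith(("gcs://", "gs://")):
        prefix = "gcs://" + prefix
    return {split: f"{prefix}/{split}.jsonl" for split in splits}

# The mapping could then be passed to datasets.load_dataset, and the audio
# column cast afterwards, e.g.:
#   from datasets import load_dataset, Audio
#   ds = load_dataset("json", data_files=gcs_data_files("gcs://my-bucket/salt"))
#   ds = ds.cast_column("audio", Audio(sampling_rate=16000))
```

`datasets.load_dataset` resolves `gcs://` paths through fsspec, and `cast_column` with `datasets.Audio` decodes the audio lazily at access time.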
TODO
- Update `salt.constants.SALT_LANGUAGE_TOKENS_WHISPER` to support new languages. Currently we only have the following: