diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index a7f4f0a902..aa15123d8e 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -19,7 +19,7 @@ Steps to reproduce the behavior (**always include the command you ran**): #### Code sample - ### Expected behavior @@ -28,7 +28,7 @@ Minimal means having the shortest code but still preserving the bug. --> ### Environment - - fairseq Version (e.g., 1.0 or master): + - fairseq Version (e.g., 1.0 or main): - PyTorch Version (e.g., 1.0) - OS (e.g., Linux): - How you installed fairseq (`pip`, source): diff --git a/.github/ISSUE_TEMPLATE/how-to-question.md b/.github/ISSUE_TEMPLATE/how-to-question.md index 4beb180dbf..04f3f15d3e 100644 --- a/.github/ISSUE_TEMPLATE/how-to-question.md +++ b/.github/ISSUE_TEMPLATE/how-to-question.md @@ -6,9 +6,9 @@ labels: 'question, needs triage' ## ❓ Questions and Help -### Before asking: -1. search the issues. -2. search the docs. +### Before asking: +1. search the issues. +2. search the docs. @@ -16,13 +16,13 @@ labels: 'question, needs triage' #### Code - + #### What have you tried? #### What's your environment? - - fairseq Version (e.g., 1.0 or master): + - fairseq Version (e.g., 1.0 or main): - PyTorch Version (e.g., 1.0) - OS (e.g., Linux): - How you installed fairseq (`pip`, source): diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index b28ff98e7b..d005e2df4f 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,15 +1,15 @@ # Before submitting - [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements) -- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)? -- [ ] Did you make sure to update the docs? -- [ ] Did you write any new necessary tests? +- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/main/CONTRIBUTING.md)? +- [ ] Did you make sure to update the docs? +- [ ] Did you write any new necessary tests? ## What does this PR do? Fixes # (issue). -## PR review -Anyone in the community is free to review the PR once the tests have passed. +## PR review +Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. ## Did you have fun? diff --git a/.github/stale.yml b/.github/stale.yml new file mode 100644 index 0000000000..b12867dab0 --- /dev/null +++ b/.github/stale.yml @@ -0,0 +1,30 @@ +# Configuration for probot-stale - https://github.com/probot/stale +# Mostly copied from github.com/facebook/react/blob/master/.github/stale.yml +# Number of days of inactivity before an issue becomes stale +daysUntilStale: 90 +# Number of days of inactivity before a stale issue is closed +daysUntilClose: 7 +# Issues with these labels will never be considered stale +exemptLabels: + - bug +# Label to use when marking an issue as stale +staleLabel: stale +issues: + # Comment to post when marking an issue as stale. + markComment: > + This issue has been automatically marked as stale. + **If this issue is still affecting you, please leave any comment** (for example, "bump"), and we'll keep it open. + We are sorry that we haven't been able to prioritize it yet. If you have any new additional information, please include it with your comment! + # Comment to post when closing a stale issue. + closeComment: > + Closing this issue after a prolonged period of inactivity. If this issue is still present in the latest release, please create a new issue with up-to-date information. Thank you! +pulls: + # Comment to post when marking a pull request as stale. + markComment: > + This pull request has been automatically marked as stale. + **If this pull request is still relevant, please leave any comment** (for example, "bump"), and we'll keep it open. + We are sorry that we haven't been able to prioritize reviewing it yet. Your contribution is very much appreciated. + # Comment to post when closing a stale pull request. + closeComment: > + Closing this pull request after a prolonged period of inactivity. If this issue is still present in the latest release, please ask for this pull request to be reopened. Thank you! + diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 6ae8093a8a..f493f91f0d 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1,10 +1,10 @@ name: build on: - # Trigger the workflow on push to master or any pull request + # Trigger the workflow on push to main or any pull request push: branches: - - master + - main pull_request: jobs: @@ -19,26 +19,37 @@ jobs: runs-on: ${{ matrix.platform }} steps: - - uses: actions/checkout@v1 + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v1 + uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} + - name: Conditionally install pytorch if: matrix.platform == 'windows-latest' run: pip3 install torch -f https://download.pytorch.org/whl/torch_stable.html + - name: Install locally run: | python -m pip install --upgrade pip + git submodule update --init --recursive python setup.py build_ext --inplace python -m pip install --editable . + + - name: Install optional test requirements + run: | + python -m pip install iopath transformers pyarrow + python -m pip install git+https://github.com/facebookresearch/fairscale.git@master + - name: Lint with flake8 run: | pip install flake8 # stop the build if there are Python syntax errors or undefined names - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --extend-exclude fairseq/model_parallel/megatron # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --extend-exclude fairseq/model_parallel/megatron + - name: Run tests run: | python setup.py test diff --git a/.github/workflows/build_wheels.yml b/.github/workflows/build_wheels.yml new file mode 100644 index 0000000000..7261708596 --- /dev/null +++ b/.github/workflows/build_wheels.yml @@ -0,0 +1,41 @@ +name: build_wheels + +on: + push: + branches: + - v[0-9]+.[0-9]+.[x0-9]+ + tags: + - v* + +jobs: + build_wheels: + name: Build wheels on ${{ matrix.os }} + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, macos-latest] + + steps: + - uses: actions/checkout@v2 + + - name: Install Python + uses: actions/setup-python@v2 + with: + python-version: '3.7' + + - name: Install cibuildwheel + run: | + python -m pip install cibuildwheel + + - name: Build wheels for CPython + run: | + python -m cibuildwheel --output-dir dist + env: + CIBW_BUILD: "cp36-*64 cp37-*64 cp38-*64" + CIBW_MANYLINUX_X86_64_IMAGE: manylinux1 + CIBW_BEFORE_BUILD: git submodule update --init --recursive && pip install . + + - uses: actions/upload-artifact@v2 + with: + name: wheels + path: ./dist/*.whl diff --git a/.github/workflows/build_windows.yml b/.github/workflows/build_windows.yml deleted file mode 100644 index 3161fd09c7..0000000000 --- a/.github/workflows/build_windows.yml +++ /dev/null @@ -1,44 +0,0 @@ -name: build_windows - -on: - # Trigger the workflow on push to master or any pull request - push: - branches: - - master - pull_request: - -jobs: - build: - - strategy: - max-parallel: 4 - matrix: - platform: [windows-latest] - python-version: [3.6, 3.7] - - runs-on: ${{ matrix.platform }} - - steps: - - uses: actions/checkout@v1 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v1 - with: - python-version: ${{ matrix.python-version }} - - name: Conditionally install pytorch - if: matrix.platform == 'windows-latest' - run: pip3 install torch -f https://download.pytorch.org/whl/torch_stable.html - - name: Install locally - run: | - python -m pip install --upgrade pip - python setup.py build_ext --inplace - python -m pip install --editable . - - name: Lint with flake8 - run: | - pip install flake8 - # stop the build if there are Python syntax errors or undefined names - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - - name: Run tests - run: | - python setup.py test diff --git a/.gitignore b/.gitignore index a7c4577149..4112804793 100644 --- a/.gitignore +++ b/.gitignore @@ -113,6 +113,7 @@ ENV/ /fairseq/temporal_convolution_tbc /fairseq/modules/*_layer/*_forward.cu /fairseq/modules/*_layer/*_backward.cu +/fairseq/version.py # data data-bin/ @@ -130,3 +131,6 @@ data-bin/ # Experimental Folder experimental/* + +# Weights and Biases logs +wandb/ diff --git a/.gitmodules b/.gitmodules index df0d3d3071..07a55d45d4 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,7 +1,3 @@ -[submodule "fairseq/models/huggingface/transformers"] - path = fairseq/models/huggingface/transformers - url = https://github.com/myleott/transformers.git - branch = fairseq [submodule "fairseq/model_parallel/megatron"] path = fairseq/model_parallel/megatron url = https://github.com/ngoyal2707/Megatron-LM diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 4d7ca6a98e..3930c46196 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -5,7 +5,7 @@ possible. ## Pull Requests We actively welcome your pull requests. -1. Fork the repo and create your branch from `master`. +1. Fork the repo and create your branch from `main`. 2. If you've added code that should be tested, add tests. 3. If you've changed APIs, update the documentation. 4. Ensure the test suite passes. diff --git a/README.md b/README.md index 8542791c2f..dd68717480 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@
-
+
+* **Convolutional Neural Networks (CNN)**
+ + [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/conv_lm/README.md)
+ + [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md)
+ + [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
+ + [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md)
+ + [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
+* **LightConv and DynamicConv models**
+ + [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md)
+* **Long Short-Term Memory (LSTM) networks**
+ + Effective Approaches to Attention-based Neural Machine Translation (Luong et al., 2015)
+* **Transformer (self-attention) networks**
+ + Attention Is All You Need (Vaswani et al., 2017)
+ + [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md)
+ + [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md)
+ + [Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)](examples/language_model/README.adaptive_inputs.md)
+ + [Lexically constrained decoding with dynamic beam allocation (Post & Vilar, 2018)](examples/constrained_decoding/README.md)
+ + [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context (Dai et al., 2019)](examples/truncated_bptt/README.md)
+ + [Adaptive Attention Span in Transformers (Sukhbaatar et al., 2019)](examples/adaptive_span/README.md)
+ + [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
+ + [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
+ + [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
+ + [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md )
+ + [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et at., 2020)](examples/mbart/README.md)
+ + [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md)
+ + [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md)
+ + [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md)
+ + [Generating Medical Reports from Patient-Doctor Conversations Using Sequence-to-Sequence Models (Enarvi et al., 2020)](examples/pointer_generator/README.md)
+ + [Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)](examples/linformer/README.md)
+ + [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md)
+ + [Deep Transformers with Latent Depth (Li et al., 2020)](examples/latent_depth/README.md)
+ + [Unsupervised Cross-lingual Representation Learning for Speech Recognition (Conneau et al., 2020)](https://arxiv.org/abs/2006.13979)
+ + [Robust wav2vec 2.0: Analyzing Domain Shift in Self-Supervised Pre-Training (Hsu, et al., 2021)](https://arxiv.org/abs/2104.01027)
+ + [Unsupervised Speech Recognition (Baevski, et al., 2021)](https://arxiv.org/abs/2105.11084)
+* **Non-autoregressive Transformers**
+ + Non-Autoregressive Neural Machine Translation (Gu et al., 2017)
+ + Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018)
+ + Insertion Transformer: Flexible Sequence Generation via Insertion Operations (Stern et al. 2019)
+ + Mask-Predict: Parallel Decoding of Conditional Masked Language Models (Ghazvininejad et al., 2019)
+ + [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
+* **Finetuning**
+ + [Better Fine-Tuning by Reducing Representational Collapse (Aghajanyan et al. 2020)](examples/rxf/README.md)
+
+
+
+* September 2020: [Added Linformer code](examples/linformer/README.md)
+* September 2020: [Added pointer-generator networks](examples/pointer_generator/README.md)
+* August 2020: [Added lexically constrained decoding](examples/constrained_decoding/README.md)
+* August 2020: [wav2vec2 models and code released](examples/wav2vec/README.md)
+* July 2020: [Unsupervised Quality Estimation code released](examples/unsupervised_quality_estimation/README.md)
+* May 2020: [Follow fairseq on Twitter](https://twitter.com/fairseq)
+* April 2020: [Monotonic Multihead Attention code released](examples/simultaneous_translation/README.md)
+* April 2020: [Quant-Noise code released](examples/quant_noise/README.md)
+* April 2020: [Initial model parallel support and 11B parameters unidirectional LM released](examples/megatron_11b/README.md)
+* March 2020: [Byte-level BPE code released](examples/byte_level_bpe/README.md)
+* February 2020: [mBART model and code released](examples/mbart/README.md)
+* February 2020: [Added tutorial for back-translation](https://github.com/pytorch/fairseq/tree/main/examples/backtranslation#training-your-own-model-wmt18-english-german)
+* December 2019: [fairseq 0.9.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.9.0)
+* November 2019: [VizSeq released (a visual analysis toolkit for evaluating fairseq models)](https://facebookresearch.github.io/vizseq/docs/getting_started/fairseq_example)
+* November 2019: [CamemBERT model and code released](examples/camembert/README.md)
+* November 2019: [BART model and code released](examples/bart/README.md)
+* November 2019: [XLM-R models and code released](examples/xlmr/README.md)
+* September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md)
+* August 2019: [WMT'19 models released](examples/wmt19/README.md)
+* July 2019: fairseq relicensed under MIT license
+* July 2019: [RoBERTa models and code released](examples/roberta/README.md)
+* June 2019: [wav2vec models and code released](examples/wav2vec/README.md)
+
+
+
+
+FSDP currently has several limitations compared to fairseq's default DDP backend (PyTorch DDP):
+* while FSDP is full compatible with pointwise Optimizers (e.g., Adam, AdamW, Adadelta, Adamax, SGD, etc.), it is not currently compatible with non-pointwise Optimizers (e.g., Adagrad, Adafactor, LAMB, etc.)
+* FSDP depends on flattening the parameters, so models that currently require `--fp16-no-flatten-grads` may not be supported
+
+See the [fairscale docs](https://fairscale.readthedocs.io/en/latest/api/nn/fsdp_tips.html) for a more detailed
+explanation of these and other limitations.
+
+
+
+
+
+```
+(...)
+2021-03-08 12:29:51 | INFO | fairseq_cli.train | num. model params: 13,110,865,920 (num. trained: 13,110,865,920)
+(...)
+2021-03-08 12:29:51 | INFO | fairseq_cli.train | training on 1 devices (GPUs/TPUs)
+2021-03-08 12:29:51 | INFO | fairseq_cli.train | max tokens per GPU = None and batch size per GPU = 8
+(...)
+Adam Optimizer #0 is created with AVX2 arithmetic capability.
+Config: alpha=0.000100, betas=(0.900000, 0.980000), weight_decay=0.000000, adam_w=1
+(...)
+2021-03-08 12:31:36 | INFO | train_inner | {"epoch": 1, "update": 0.0, "loss": "16.475", "ppl": "91120.8", "wps": "0", "ups": "0", "wpb": "16384", "bsz": "8", "num_updates": "1", "lr": "2e-05", "gnorm": "20.751", "loss_scale": "4", "train_wall": "99", "gb_free": "9.3", "wall": "105"}
+2021-03-08 12:32:33 | INFO | train_inner | {"epoch": 1, "update": 0.0, "loss": "16.446", "ppl": "89281.6", "wps": "288.7", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "2", "lr": "4e-05", "gnorm": "19.777", "loss_scale": "4", "train_wall": "57", "gb_free": "9.3", "wall": "161"}
+2021-03-08 12:33:12 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 2.0
+2021-03-08 12:33:51 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 1.0
+2021-03-08 12:34:45 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "25.22", "ppl": "3.90691e+07", "wps": "123.4", "ups": "0.01", "wpb": "16384", "bsz": "8", "num_updates": "3", "lr": "6e-05", "gnorm": "131.281", "loss_scale": "1", "train_wall": "133", "gb_free": "9.3", "wall": "294"}
+2021-03-08 12:35:43 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "18.079", "ppl": "276809", "wps": "285.5", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "4", "lr": "8e-05", "gnorm": "13.776", "loss_scale": "1", "train_wall": "57", "gb_free": "9.3", "wall": "351"}
+2021-03-08 12:36:35 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "23.729", "ppl": "1.39088e+07", "wps": "316.7", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "5", "lr": "0.0001", "gnorm": "72.774", "loss_scale": "1", "train_wall": "52", "gb_free": "9.3", "wall": "403"}
+2021-03-08 12:37:28 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "20.429", "ppl": "1.41203e+06", "wps": "307.6", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "6", "lr": "8e-05", "gnorm": "60.846", "loss_scale": "1", "train_wall": "53", "gb_free": "9.3", "wall": "456"}
+2021-03-08 12:38:27 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "18.965", "ppl": "511684", "wps": "279.4", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "7", "lr": "6e-05", "gnorm": "22.687", "loss_scale": "1", "train_wall": "59", "gb_free": "9.3", "wall": "515"}
+2021-03-08 12:39:18 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "18.345", "ppl": "332887", "wps": "319.1", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "8", "lr": "4e-05", "gnorm": "8.451", "loss_scale": "1", "train_wall": "51", "gb_free": "9.3", "wall": "566"}
+2021-03-08 12:40:11 | INFO | train_inner | {"epoch": 1, "update": 0.002, "loss": "18.262", "ppl": "314336", "wps": "305.9", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "9", "lr": "2e-05", "gnorm": "6.457", "loss_scale": "1", "train_wall": "54", "gb_free": "9.3", "wall": "620"}
+2021-03-08 12:41:04 | INFO | train_inner | {"epoch": 1, "update": 0.002, "loss": "17.556", "ppl": "192686", "wps": "311.8", "ups": "0.02", "wpb": "16384", "bsz": "8", "num_updates": "10", "lr": "0", "gnorm": "5.796", "loss_scale": "1", "train_wall": "53", "gb_free": "9.3", "wall": "673"}
+2021-03-08 12:41:04 | INFO | fairseq_cli.train | Stopping training due to num_updates: 10 >= max_update: 10
+2021-03-08 12:41:04 | INFO | fairseq_cli.train | begin validation on "valid" subset
+2021-03-08 12:43:15 | INFO | valid | {"epoch": 1, "valid_loss": "17.953", "valid_ppl": "253807", "valid_wps": "1868.4", "valid_wpb": "15400.2", "valid_bsz": "7.6", "valid_num_updates": "10"}
+2021-03-08 12:43:15 | INFO | fairseq_cli.train | end of epoch 1 (average epoch stats below)
+2021-03-08 12:43:15 | INFO | train | {"epoch": 1, "train_loss": "19.351", "train_ppl": "668509", "train_wps": "210.9", "train_ups": "0.01", "train_wpb": "16384", "train_bsz": "8", "train_num_updates": "10", "train_lr": "0", "train_gnorm": "36.26", "train_loss_scale": "1", "train_train_wall": "667", "train_gb_free": "9.3", "train_wall": "804"}
+2021-03-08 12:43:15 | INFO | fairseq_cli.train | done training in 798.6 seconds
+```
+
+
+
+```
+(...)
+2021-03-08 18:04:09 | INFO | fairseq_cli.train | num. model params: 13,110,865,920 (num. trained: 13,110,865,920)
+(...)
+2021-03-08 18:04:09 | INFO | fairseq_cli.train | training on 8 devices (GPUs/TPUs)
+2021-03-08 18:04:09 | INFO | fairseq_cli.train | max tokens per GPU = None and batch size per GPU = 8
+(...)
+Adam Optimizer #0 is created with AVX2 arithmetic capability.
+Config: alpha=0.000100, betas=(0.900000, 0.980000), weight_decay=0.000000, adam_w=1
+(...)
+2021-03-08 18:05:06 | INFO | train_inner | {"epoch": 1, "update": 0.001, "loss": "16.408", "ppl": "86945.6", "wps": "0", "ups": "0", "wpb": "131072", "bsz": "64", "num_updates": "1", "lr": "2e-05", "gnorm": "18.27", "loss_scale": "4", "train_wall": "47", "gb_free": "9.3", "wall": "56"}
+2021-03-08 18:05:45 | INFO | train_inner | {"epoch": 1, "update": 0.002, "loss": "16.352", "ppl": "83644.3", "wps": "3283.4", "ups": "0.03", "wpb": "131072", "bsz": "64", "num_updates": "2", "lr": "4e-05", "gnorm": "18.411", "loss_scale": "4", "train_wall": "40", "gb_free": "9.3", "wall": "96"}
+2021-03-08 18:06:21 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 2.0
+2021-03-08 18:06:56 | INFO | fairseq.trainer | NOTE: gradient overflow detected, ignoring gradient, setting loss scale to: 1.0
+2021-03-08 18:07:37 | INFO | train_inner | {"epoch": 1, "update": 0.006, "loss": "23.682", "ppl": "1.34537e+07", "wps": "1176.6", "ups": "0.01", "wpb": "131072", "bsz": "64", "num_updates": "3", "lr": "6e-05", "gnorm": "119.682", "loss_scale": "1", "train_wall": "111", "gb_free": "9.3", "wall": "208"}
+2021-03-08 18:08:18 | INFO | train_inner | {"epoch": 1, "update": 0.007, "loss": "18.988", "ppl": "519921", "wps": "3189.1", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "4", "lr": "8e-05", "gnorm": "14.934", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "249"}
+2021-03-08 18:08:59 | INFO | train_inner | {"epoch": 1, "update": 0.008, "loss": "20.08", "ppl": "1.10798e+06", "wps": "3223.1", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "5", "lr": "0.0001", "gnorm": "59.92", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "289"}
+2021-03-08 18:09:39 | INFO | train_inner | {"epoch": 1, "update": 0.009, "loss": "18.323", "ppl": "327980", "wps": "3256.6", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "6", "lr": "8e-05", "gnorm": "37.425", "loss_scale": "1", "train_wall": "40", "gb_free": "9.3", "wall": "330"}
+2021-03-08 18:10:20 | INFO | train_inner | {"epoch": 1, "update": 0.01, "loss": "17.264", "ppl": "157354", "wps": "3188.7", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "7", "lr": "6e-05", "gnorm": "10.824", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "371"}
+2021-03-08 18:11:01 | INFO | train_inner | {"epoch": 1, "update": 0.011, "loss": "16.794", "ppl": "113647", "wps": "3230", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "8", "lr": "4e-05", "gnorm": "5.616", "loss_scale": "1", "train_wall": "41", "gb_free": "9.3", "wall": "411"}
+2021-03-08 18:11:39 | INFO | train_inner | {"epoch": 1, "update": 0.012, "loss": "16.706", "ppl": "106938", "wps": "3384", "ups": "0.03", "wpb": "131072", "bsz": "64", "num_updates": "9", "lr": "2e-05", "gnorm": "5.318", "loss_scale": "1", "train_wall": "39", "gb_free": "9.3", "wall": "450"}
+2021-03-08 18:12:19 | INFO | train_inner | {"epoch": 1, "update": 0.013, "loss": "16.548", "ppl": "95796.2", "wps": "3274.4", "ups": "0.02", "wpb": "131072", "bsz": "64", "num_updates": "10", "lr": "0", "gnorm": "5.22", "loss_scale": "1", "train_wall": "40", "gb_free": "9.3", "wall": "490"}
+2021-03-08 18:12:19 | INFO | fairseq_cli.train | Stopping training due to num_updates: 10 >= max_update: 10
+2021-03-08 18:12:19 | INFO | fairseq_cli.train | begin validation on "valid" subset
+2021-03-08 18:12:45 | INFO | valid | {"epoch": 1, "valid_loss": "16.624", "valid_ppl": "101000", "valid_wps": "10855.9", "valid_wpb": "123202", "valid_bsz": "60.5", "valid_num_updates": "10"}
+2021-03-08 18:12:45 | INFO | fairseq_cli.train | end of epoch 1 (average epoch stats below)
+2021-03-08 18:12:45 | INFO | train | {"epoch": 1, "train_loss": "18.114", "train_ppl": "283776", "train_wps": "2567.8", "train_ups": "0.02", "train_wpb": "131072", "train_bsz": "64", "train_num_updates": "10", "train_lr": "0", "train_gnorm": "29.562", "train_loss_scale": "1", "train_train_wall": "480", "train_gb_free": "9.3", "train_wall": "516"}
+2021-03-08 18:12:45 | INFO | fairseq_cli.train | done training in 509.9 seconds
+```
+
+
+
-
+
@@ -14,114 +14,159 @@ Fairseq(-py) is a sequence modeling toolkit that allows researchers and
developers to train custom models for translation, summarization, language
modeling and other text generation tasks.
-### What's New:
+We provide reference implementations of various sequence modeling papers:
+
+
List of implemented papers
Previous updates
+Limitations
How it works
+
+See the [fairscale docs](https://fairscale.readthedocs.io/en/latest/api/nn/fsdp_tips.html) for a more detailed
+explanation of how FSDP works.
+
+Example output
Example output
([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 1026M | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.bz2)
-Adaptive Inputs
([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 247M | [WikiText-103](https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.bz2)
+Adaptive Inputs
([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 1026M | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.tar.bz2)
+Adaptive Inputs
([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 247M | [WikiText-103](https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.v2.tar.bz2)
## Training an LM with adaptive inputs
-First, see the general [language modeling README](../README.md) for instructions
-on preprocessing the WikiText-103 data.
+First, see the general [language modeling README](README.md) for instructions on
+preprocessing the WikiText-103 data.
Then use the following training command to train a model with adaptive inputs
using the `transformer_lm_wiki103` model architecture:
@@ -19,10 +19,10 @@ fairseq-train --task language_modeling \
data-bin/wikitext-103 \
--save-dir checkpoints/transformer_wikitext-103 \
--arch transformer_lm_wiki103 \
- --max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
- --warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
+ --max-update 286000 --lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
+ --warmup-updates 16000 --warmup-init-lr 1e-07 --stop-min-lr 1e-09 --optimizer nag --min-lr 0.0001 --clip-norm 0.1 \
--criterion adaptive_loss --max-tokens 3072 --update-freq 3 --tokens-per-sample 3072 --seed 1 \
- --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d
+ --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=legacy_ddp
```
## Citation
diff --git a/examples/language_model/conv_lm/README.md b/examples/language_model/README.conv.md
similarity index 76%
rename from examples/language_model/conv_lm/README.md
rename to examples/language_model/README.conv.md
index 83ac0b454b..1ff8635906 100644
--- a/examples/language_model/conv_lm/README.md
+++ b/examples/language_model/README.conv.md
@@ -2,8 +2,7 @@
## Example usage
-First download and preprocess the data following the main [language modeling
-README](../README.md).
+First download and preprocess the data following the main [language modeling README](README.md).
Then to train a convolutional LM using the `fconv_lm_dauphin_wikitext103`
architecture:
@@ -12,11 +11,14 @@ fairseq-train --task language_modeling \
data-bin/wikitext-103 \
--save-dir checkpoints/fconv_wikitext-103 \
--arch fconv_lm_dauphin_wikitext103 \
- --max-epoch 35 \ --optimizer nag \
+ --adaptive-softmax-cutoff 10000,20000,200000 \
+ --dropout 0.2 \
+ --criterion adaptive_loss \
+ --optimizer nag --clip-norm 0.1 --weight-decay 5e-06 \
--lr 1.0 --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \
- --clip-norm 0.1 --dropout 0.2 --weight-decay 5e-06 --criterion adaptive_loss \
- --adaptive-softmax-cutoff 10000,20000,200000 --max-tokens 1024 --tokens-per-sample 1024 \
- --ddp-backend=no_c10d
+ --max-tokens 1024 --tokens-per-sample 1024 \
+ --ddp-backend legacy_ddp \
+ --max-epoch 35
```
And evaluate with:
diff --git a/examples/language_model/README.md b/examples/language_model/README.md
index 43f3381a1f..e78ea48e08 100644
--- a/examples/language_model/README.md
+++ b/examples/language_model/README.md
@@ -5,7 +5,7 @@
Model | Description | Dataset | Download
---|---|---|---
`transformer_lm.gbw.adaptive_huge` | Adaptive Inputs
([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853))
1026M params | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.tar.bz2)
-`transformer_lm.wiki103.adaptive` | Adaptive Inputs
([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853))
247M params | [WikiText-103](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.tar.bz2)
+`transformer_lm.wiki103.adaptive` | Adaptive Inputs
([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853))
247M params | [WikiText-103](https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.v2.tar.bz2)
`transformer_lm.wmt19.en` | English LM
([Ng et al., 2019](https://arxiv.org/abs/1907.06616)) | [WMT News Crawl](http://data.statmt.org/news-crawl/) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.en.tar.gz)
`transformer_lm.wmt19.de` | German LM
([Ng et al., 2019](https://arxiv.org/abs/1907.06616)) | [WMT News Crawl](http://data.statmt.org/news-crawl/) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.de.tar.gz)
`transformer_lm.wmt19.ru` | Russian LM
([Ng et al., 2019](https://arxiv.org/abs/1907.06616)) | [WMT News Crawl](http://data.statmt.org/news-crawl/) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.ru.tar.gz)
@@ -72,8 +72,7 @@ fairseq-preprocess \
### 2) Train a language model
Next we'll train a basic transformer language model on wikitext-103. For more
-advanced examples (e.g., using [adaptive inputs](https://arxiv.org/abs/1809.10853)),
-please see the [Transformer LM README](transformer_lm/README.md).
+advanced usage, see the [adaptive inputs README](README.adaptive_inputs.md).
To train a basic LM (assumes 2 GPUs):
```
@@ -100,7 +99,7 @@ number of GPUs.
```bash
fairseq-eval-lm data-bin/wikitext-103 \
--path checkpoints/transformer_wiki103/checkpoint_best.pt \
- --max-sentences 2 \
+ --batch-size 2 \
--tokens-per-sample 512 \
--context-window 400
# | Evaluated 245569 tokens in 56.1s (4379.02 tokens/s)
@@ -120,5 +119,5 @@ dataset, but results in better (lower) perplexity.
## Convolutional language models
-Please see the [convolutional LM README](conv_lm/README.md) for instructions to
-train convolutional language models.
+Please see the [convolutional LM README](README.conv.md) for instructions on
+training convolutional language models.
diff --git a/examples/laser/README.md b/examples/laser/README.md
new file mode 100644
index 0000000000..66acada04f
--- /dev/null
+++ b/examples/laser/README.md
@@ -0,0 +1,144 @@
+# LASER Language-Agnostic SEntence Representations
+
+LASER is a library to calculate and use multilingual sentence embeddings.
+
+You can find more information about LASER and how to use it on the official [LASER repository](https://github.com/facebookresearch/LASER).
+
+This folder contains source code for training LASER embeddings.
+
+
+## Prepare data and configuration file
+
+Binarize your data with fairseq, as described [here](https://fairseq.readthedocs.io/en/latest/getting_started.html#data-pre-processing).
+
+Create a json config file with this format:
+```
+{
+ "src_vocab": "/path/to/spm.src.cvocab",
+ "tgt_vocab": "/path/to/spm.tgt.cvocab",
+ "train": [
+ {
+ "type": "translation",
+ "id": 0,
+ "src": "/path/to/srclang1-tgtlang0/train.srclang1",
+ "tgt": "/path/to/srclang1-tgtlang0/train.tgtlang0"
+ },
+ {
+ "type": "translation",
+ "id": 1,
+ "src": "/path/to/srclang1-tgtlang1/train.srclang1",
+ "tgt": "/path/to/srclang1-tgtlang1/train.tgtlang1"
+ },
+ {
+ "type": "translation",
+ "id": 0,
+ "src": "/path/to/srclang2-tgtlang0/train.srclang2",
+ "tgt": "/path/to/srclang2-tgtlang0/train.tgtlang0"
+ },
+ {
+ "type": "translation",
+ "id": 1,
+ "src": "/path/to/srclang2-tgtlang1/train.srclang2",
+ "tgt": "/path/to/srclang2-tgtlang1/train.tgtlang1"
+ },
+ ...
+ ],
+ "valid": [
+ {
+ "type": "translation",
+ "id": 0,
+ "src": "/unused",
+ "tgt": "/unused"
+ }
+ ]
+}
+```
+where paths are paths to binarized indexed fairseq dataset files.
+`id` represents the target language id.
+
+
+## Training Command Line Example
+
+```
+fairseq-train \
+ /path/to/configfile_described_above.json \
+ --user-dir examples/laser/laser_src \
+ --log-interval 100 --log-format simple \
+ --task laser --arch laser_lstm \
+ --save-dir . \
+ --optimizer adam \
+ --lr 0.001 \
+ --lr-scheduler inverse_sqrt \
+ --clip-norm 5 \
+ --warmup-updates 90000 \
+ --update-freq 2 \
+ --dropout 0.0 \
+ --encoder-dropout-out 0.1 \
+ --max-tokens 2000 \
+ --max-epoch 50 \
+ --encoder-bidirectional \
+ --encoder-layers 5 \
+ --encoder-hidden-size 512 \
+ --decoder-layers 1 \
+ --decoder-hidden-size 2048 \
+ --encoder-embed-dim 320 \
+ --decoder-embed-dim 320 \
+ --decoder-lang-embed-dim 32 \
+ --warmup-init-lr 0.001 \
+ --disable-validation
+```
+
+
+## Applications
+
+We showcase several applications of multilingual sentence embeddings
+with code to reproduce our results (in the directory "tasks").
+
+* [**Cross-lingual document classification**](https://github.com/facebookresearch/LASER/tree/master/tasks/mldoc) using the
+ [*MLDoc*](https://github.com/facebookresearch/MLDoc) corpus [2,6]
+* [**WikiMatrix**](https://github.com/facebookresearch/LASER/tree/master/tasks/WikiMatrix)
+ Mining 135M Parallel Sentences in 1620 Language Pairs from Wikipedia [7]
+* [**Bitext mining**](https://github.com/facebookresearch/LASER/tree/master/tasks/bucc) using the
+ [*BUCC*](https://comparable.limsi.fr/bucc2018/bucc2018-task.html) corpus [3,5]
+* [**Cross-lingual NLI**](https://github.com/facebookresearch/LASER/tree/master/tasks/xnli)
+ using the [*XNLI*](https://www.nyu.edu/projects/bowman/xnli/) corpus [4,5,6]
+* [**Multilingual similarity search**](https://github.com/facebookresearch/LASER/tree/master/tasks/similarity) [1,6]
+* [**Sentence embedding of text files**](https://github.com/facebookresearch/LASER/tree/master/tasks/embed)
+ example how to calculate sentence embeddings for arbitrary text files in any of the supported language.
+
+**For all tasks, we use exactly the same multilingual encoder, without any task specific optimization or fine-tuning.**
+
+
+
+## References
+
+[1] Holger Schwenk and Matthijs Douze,
+ [*Learning Joint Multilingual Sentence Representations with Neural Machine Translation*](https://aclanthology.info/papers/W17-2619/w17-2619),
+ ACL workshop on Representation Learning for NLP, 2017
+
+[2] Holger Schwenk and Xian Li,
+ [*A Corpus for Multilingual Document Classification in Eight Languages*](http://www.lrec-conf.org/proceedings/lrec2018/pdf/658.pdf),
+ LREC, pages 3548-3551, 2018.
+
+[3] Holger Schwenk,
+ [*Filtering and Mining Parallel Data in a Joint Multilingual Space*](http://aclweb.org/anthology/P18-2037)
+ ACL, July 2018
+
+[4] Alexis Conneau, Guillaume Lample, Ruty Rinott, Adina Williams, Samuel R. Bowman, Holger Schwenk and Veselin Stoyanov,
+ [*XNLI: Cross-lingual Sentence Understanding through Inference*](https://aclweb.org/anthology/D18-1269),
+ EMNLP, 2018.
+
+[5] Mikel Artetxe and Holger Schwenk,
+ [*Margin-based Parallel Corpus Mining with Multilingual Sentence Embeddings*](https://arxiv.org/abs/1811.01136)
+ arXiv, Nov 3 2018.
+
+[6] Mikel Artetxe and Holger Schwenk,
+ [*Massively Multilingual Sentence Embeddings for Zero-Shot Cross-Lingual Transfer and Beyond*](https://arxiv.org/abs/1812.10464)
+ arXiv, Dec 26 2018.
+
+[7] Holger Schwenk, Vishrav Chaudhary, Shuo Sun, Hongyu Gong and Paco Guzman,
+ [*WikiMatrix: Mining 135M Parallel Sentences in 1620 Language Pairs from Wikipedia*](https://arxiv.org/abs/1907.05791)
+ arXiv, July 11 2019.
+
+[8] Holger Schwenk, Guillaume Wenzek, Sergey Edunov, Edouard Grave and Armand Joulin
+ [*CCMatrix: Mining Billions of High-Quality Parallel Sentences on the WEB*](https://arxiv.org/abs/1911.04944)
diff --git a/validate.py b/examples/laser/laser_src/__init__.py
similarity index 61%
rename from validate.py
rename to examples/laser/laser_src/__init__.py
index 9c1c66bba5..9ffbd656d8 100644
--- a/validate.py
+++ b/examples/laser/laser_src/__init__.py
@@ -1,11 +1,8 @@
-#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from fairseq_cli.validate import cli_main
-
-
-if __name__ == '__main__':
- cli_main()
+from .laser_task import * # noqa
+from .laser_lstm import * # noqa
+from .laser_transformer import * # noqa
diff --git a/examples/laser/laser_src/laser_lstm.py b/examples/laser/laser_src/laser_lstm.py
new file mode 100644
index 0000000000..10df90e002
--- /dev/null
+++ b/examples/laser/laser_src/laser_lstm.py
@@ -0,0 +1,585 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from fairseq import options, utils
+
+from fairseq.models import (
+ FairseqEncoder,
+ FairseqIncrementalDecoder,
+ FairseqEncoderDecoderModel,
+ register_model,
+ register_model_architecture,
+)
+
+
+@register_model("laser_lstm")
+class LSTMModel(FairseqEncoderDecoderModel):
+ def __init__(self, encoder, decoder):
+ super().__init__(encoder, decoder)
+
+ def forward(
+ self,
+ src_tokens,
+ src_lengths,
+ prev_output_tokens=None,
+ tgt_tokens=None,
+ tgt_lengths=None,
+ target_language_id=None,
+ dataset_name="",
+ ):
+ assert target_language_id is not None
+
+ src_encoder_out = self.encoder(src_tokens, src_lengths, dataset_name)
+ return self.decoder(
+ prev_output_tokens, src_encoder_out, lang_id=target_language_id
+ )
+
+ @staticmethod
+ def add_args(parser):
+ """Add model-specific arguments to the parser."""
+ parser.add_argument(
+ "--dropout",
+ default=0.1,
+ type=float,
+ metavar="D",
+ help="dropout probability",
+ )
+ parser.add_argument(
+ "--encoder-embed-dim",
+ type=int,
+ metavar="N",
+ help="encoder embedding dimension",
+ )
+ parser.add_argument(
+ "--encoder-embed-path",
+ default=None,
+ type=str,
+ metavar="STR",
+ help="path to pre-trained encoder embedding",
+ )
+ parser.add_argument(
+ "--encoder-hidden-size", type=int, metavar="N", help="encoder hidden size"
+ )
+ parser.add_argument(
+ "--encoder-layers", type=int, metavar="N", help="number of encoder layers"
+ )
+ parser.add_argument(
+ "--encoder-bidirectional",
+ action="store_true",
+ help="make all layers of encoder bidirectional",
+ )
+ parser.add_argument(
+ "--decoder-embed-dim",
+ type=int,
+ metavar="N",
+ help="decoder embedding dimension",
+ )
+ parser.add_argument(
+ "--decoder-embed-path",
+ default=None,
+ type=str,
+ metavar="STR",
+ help="path to pre-trained decoder embedding",
+ )
+ parser.add_argument(
+ "--decoder-hidden-size", type=int, metavar="N", help="decoder hidden size"
+ )
+ parser.add_argument(
+ "--decoder-layers", type=int, metavar="N", help="number of decoder layers"
+ )
+ parser.add_argument(
+ "--decoder-out-embed-dim",
+ type=int,
+ metavar="N",
+ help="decoder output embedding dimension",
+ )
+ parser.add_argument(
+ "--decoder-zero-init",
+ type=str,
+ metavar="BOOL",
+ help="initialize the decoder hidden/cell state to zero",
+ )
+ parser.add_argument(
+ "--decoder-lang-embed-dim",
+ type=int,
+ metavar="N",
+ help="decoder language embedding dimension",
+ )
+ parser.add_argument(
+ "--fixed-embeddings",
+ action="store_true",
+ help="keep embeddings fixed (ENCODER ONLY)",
+ ) # TODO Also apply to decoder embeddings?
+
+ # Granular dropout settings (if not specified these default to --dropout)
+ parser.add_argument(
+ "--encoder-dropout-in",
+ type=float,
+ metavar="D",
+ help="dropout probability for encoder input embedding",
+ )
+ parser.add_argument(
+ "--encoder-dropout-out",
+ type=float,
+ metavar="D",
+ help="dropout probability for encoder output",
+ )
+ parser.add_argument(
+ "--decoder-dropout-in",
+ type=float,
+ metavar="D",
+ help="dropout probability for decoder input embedding",
+ )
+ parser.add_argument(
+ "--decoder-dropout-out",
+ type=float,
+ metavar="D",
+ help="dropout probability for decoder output",
+ )
+
+ @classmethod
+ def build_model(cls, args, task):
+ """Build a new model instance."""
+ # make sure that all args are properly defaulted (in case there are any new ones)
+ base_architecture(args)
+
+ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
+ num_embeddings = len(dictionary)
+ padding_idx = dictionary.pad()
+ embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
+ embed_dict = utils.parse_embedding(embed_path)
+ utils.print_embed_overlap(embed_dict, dictionary)
+ return utils.load_embedding(embed_dict, dictionary, embed_tokens)
+
+ pretrained_encoder_embed = None
+ if args.encoder_embed_path:
+ pretrained_encoder_embed = load_pretrained_embedding_from_file(
+ args.encoder_embed_path, task.source_dictionary, args.encoder_embed_dim
+ )
+ pretrained_decoder_embed = None
+ if args.decoder_embed_path:
+ pretrained_decoder_embed = load_pretrained_embedding_from_file(
+ args.decoder_embed_path, task.target_dictionary, args.decoder_embed_dim
+ )
+
+ num_langs = task.num_tasks if hasattr(task, "num_tasks") else 0
+
+ encoder = LSTMEncoder(
+ dictionary=task.source_dictionary,
+ embed_dim=args.encoder_embed_dim,
+ hidden_size=args.encoder_hidden_size,
+ num_layers=args.encoder_layers,
+ dropout_in=args.encoder_dropout_in,
+ dropout_out=args.encoder_dropout_out,
+ bidirectional=args.encoder_bidirectional,
+ pretrained_embed=pretrained_encoder_embed,
+ fixed_embeddings=args.fixed_embeddings,
+ )
+ decoder = LSTMDecoder(
+ dictionary=task.target_dictionary,
+ embed_dim=args.decoder_embed_dim,
+ hidden_size=args.decoder_hidden_size,
+ out_embed_dim=args.decoder_out_embed_dim,
+ num_layers=args.decoder_layers,
+ dropout_in=args.decoder_dropout_in,
+ dropout_out=args.decoder_dropout_out,
+ zero_init=options.eval_bool(args.decoder_zero_init),
+ encoder_embed_dim=args.encoder_embed_dim,
+ encoder_output_units=encoder.output_units,
+ pretrained_embed=pretrained_decoder_embed,
+ num_langs=num_langs,
+ lang_embed_dim=args.decoder_lang_embed_dim,
+ )
+ return cls(encoder, decoder)
+
+
+class LSTMEncoder(FairseqEncoder):
+ """LSTM encoder."""
+
+ def __init__(
+ self,
+ dictionary,
+ embed_dim=512,
+ hidden_size=512,
+ num_layers=1,
+ dropout_in=0.1,
+ dropout_out=0.1,
+ bidirectional=False,
+ left_pad=True,
+ pretrained_embed=None,
+ padding_value=0.0,
+ fixed_embeddings=False,
+ ):
+ super().__init__(dictionary)
+ self.num_layers = num_layers
+ self.dropout_in = dropout_in
+ self.dropout_out = dropout_out
+ self.bidirectional = bidirectional
+ self.hidden_size = hidden_size
+
+ num_embeddings = len(dictionary)
+ self.padding_idx = dictionary.pad()
+ if pretrained_embed is None:
+ self.embed_tokens = Embedding(num_embeddings, embed_dim, self.padding_idx)
+ else:
+ self.embed_tokens = pretrained_embed
+ if fixed_embeddings:
+ self.embed_tokens.weight.requires_grad = False
+
+ self.lstm = LSTM(
+ input_size=embed_dim,
+ hidden_size=hidden_size,
+ num_layers=num_layers,
+ dropout=self.dropout_out if num_layers > 1 else 0.0,
+ bidirectional=bidirectional,
+ )
+ self.left_pad = left_pad
+ self.padding_value = padding_value
+
+ self.output_units = hidden_size
+ if bidirectional:
+ self.output_units *= 2
+
+ def forward(self, src_tokens, src_lengths, dataset_name):
+ if self.left_pad:
+ # convert left-padding to right-padding
+ src_tokens = utils.convert_padding_direction(
+ src_tokens,
+ self.padding_idx,
+ left_to_right=True,
+ )
+
+ bsz, seqlen = src_tokens.size()
+
+ # embed tokens
+ x = self.embed_tokens(src_tokens)
+ x = F.dropout(x, p=self.dropout_in, training=self.training)
+
+ # B x T x C -> T x B x C
+ x = x.transpose(0, 1)
+
+ # pack embedded source tokens into a PackedSequence
+ try:
+ packed_x = nn.utils.rnn.pack_padded_sequence(x, src_lengths.data.tolist())
+ except BaseException:
+ raise Exception(f"Packing failed in dataset {dataset_name}")
+
+ # apply LSTM
+ if self.bidirectional:
+ state_size = 2 * self.num_layers, bsz, self.hidden_size
+ else:
+ state_size = self.num_layers, bsz, self.hidden_size
+ h0 = x.data.new(*state_size).zero_()
+ c0 = x.data.new(*state_size).zero_()
+ packed_outs, (final_hiddens, final_cells) = self.lstm(packed_x, (h0, c0))
+
+ # unpack outputs and apply dropout
+ x, _ = nn.utils.rnn.pad_packed_sequence(
+ packed_outs, padding_value=self.padding_value
+ )
+ x = F.dropout(x, p=self.dropout_out, training=self.training)
+ assert list(x.size()) == [seqlen, bsz, self.output_units]
+
+ if self.bidirectional:
+
+ def combine_bidir(outs):
+ return torch.cat(
+ [
+ torch.cat([outs[2 * i], outs[2 * i + 1]], dim=0).view(
+ 1, bsz, self.output_units
+ )
+ for i in range(self.num_layers)
+ ],
+ dim=0,
+ )
+
+ final_hiddens = combine_bidir(final_hiddens)
+ final_cells = combine_bidir(final_cells)
+
+ encoder_padding_mask = src_tokens.eq(self.padding_idx).t()
+
+ # Set padded outputs to -inf so they are not selected by max-pooling
+ padding_mask = src_tokens.eq(self.padding_idx).t().unsqueeze(-1)
+ if padding_mask.any():
+ x = x.float().masked_fill_(padding_mask, float("-inf")).type_as(x)
+
+ # Build the sentence embedding by max-pooling over the encoder outputs
+ sentemb = x.max(dim=0)[0]
+
+ return {
+ "sentemb": sentemb,
+ "encoder_out": (x, final_hiddens, final_cells),
+ "encoder_padding_mask": encoder_padding_mask
+ if encoder_padding_mask.any()
+ else None,
+ }
+
+ def reorder_encoder_out(self, encoder_out_dict, new_order):
+ encoder_out_dict["sentemb"] = encoder_out_dict["sentemb"].index_select(
+ 0, new_order
+ )
+ encoder_out_dict["encoder_out"] = tuple(
+ eo.index_select(1, new_order) for eo in encoder_out_dict["encoder_out"]
+ )
+ if encoder_out_dict["encoder_padding_mask"] is not None:
+ encoder_out_dict["encoder_padding_mask"] = encoder_out_dict[
+ "encoder_padding_mask"
+ ].index_select(1, new_order)
+ return encoder_out_dict
+
+ def max_positions(self):
+ """Maximum input length supported by the encoder."""
+ return int(1e5) # an arbitrary large number
+
+
+class LSTMDecoder(FairseqIncrementalDecoder):
+ """LSTM decoder."""
+
+ def __init__(
+ self,
+ dictionary,
+ embed_dim=512,
+ hidden_size=512,
+ out_embed_dim=512,
+ num_layers=1,
+ dropout_in=0.1,
+ dropout_out=0.1,
+ zero_init=False,
+ encoder_embed_dim=512,
+ encoder_output_units=512,
+ pretrained_embed=None,
+ num_langs=1,
+ lang_embed_dim=0,
+ ):
+ super().__init__(dictionary)
+ self.dropout_in = dropout_in
+ self.dropout_out = dropout_out
+ self.hidden_size = hidden_size
+
+ num_embeddings = len(dictionary)
+ padding_idx = dictionary.pad()
+ if pretrained_embed is None:
+ self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
+ else:
+ self.embed_tokens = pretrained_embed
+
+ self.layers = nn.ModuleList(
+ [
+ LSTMCell(
+ input_size=encoder_output_units + embed_dim + lang_embed_dim
+ if layer == 0
+ else hidden_size,
+ hidden_size=hidden_size,
+ )
+ for layer in range(num_layers)
+ ]
+ )
+ if hidden_size != out_embed_dim:
+ self.additional_fc = Linear(hidden_size, out_embed_dim)
+ self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)
+
+ if zero_init:
+ self.sentemb2init = None
+ else:
+ self.sentemb2init = Linear(
+ encoder_output_units, 2 * num_layers * hidden_size
+ )
+
+ if lang_embed_dim == 0:
+ self.embed_lang = None
+ else:
+ self.embed_lang = nn.Embedding(num_langs, lang_embed_dim)
+ nn.init.uniform_(self.embed_lang.weight, -0.1, 0.1)
+
+ def forward(
+ self, prev_output_tokens, encoder_out_dict, incremental_state=None, lang_id=0
+ ):
+ sentemb = encoder_out_dict["sentemb"]
+ encoder_out = encoder_out_dict["encoder_out"]
+
+ if incremental_state is not None:
+ prev_output_tokens = prev_output_tokens[:, -1:]
+ bsz, seqlen = prev_output_tokens.size()
+
+ # get outputs from encoder
+ encoder_outs, _, _ = encoder_out[:3]
+ srclen = encoder_outs.size(0)
+
+ # embed tokens
+ x = self.embed_tokens(prev_output_tokens)
+ x = F.dropout(x, p=self.dropout_in, training=self.training)
+
+ # embed language identifier
+ if self.embed_lang is not None:
+ lang_ids = prev_output_tokens.data.new_full((bsz,), lang_id)
+ langemb = self.embed_lang(lang_ids)
+ # TODO Should we dropout here???
+
+ # B x T x C -> T x B x C
+ x = x.transpose(0, 1)
+
+ # initialize previous states (or get from cache during incremental generation)
+ cached_state = utils.get_incremental_state(
+ self, incremental_state, "cached_state"
+ )
+ if cached_state is not None:
+ prev_hiddens, prev_cells, input_feed = cached_state
+ else:
+ num_layers = len(self.layers)
+ if self.sentemb2init is None:
+ prev_hiddens = [
+ x.data.new(bsz, self.hidden_size).zero_() for i in range(num_layers)
+ ]
+ prev_cells = [
+ x.data.new(bsz, self.hidden_size).zero_() for i in range(num_layers)
+ ]
+ else:
+ init = self.sentemb2init(sentemb)
+ prev_hiddens = [
+ init[:, (2 * i) * self.hidden_size : (2 * i + 1) * self.hidden_size]
+ for i in range(num_layers)
+ ]
+ prev_cells = [
+ init[
+ :,
+ (2 * i + 1) * self.hidden_size : (2 * i + 2) * self.hidden_size,
+ ]
+ for i in range(num_layers)
+ ]
+ input_feed = x.data.new(bsz, self.hidden_size).zero_()
+
+ attn_scores = x.data.new(srclen, seqlen, bsz).zero_()
+ outs = []
+ for j in range(seqlen):
+ if self.embed_lang is None:
+ input = torch.cat((x[j, :, :], sentemb), dim=1)
+ else:
+ input = torch.cat((x[j, :, :], sentemb, langemb), dim=1)
+
+ for i, rnn in enumerate(self.layers):
+ # recurrent cell
+ hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i]))
+
+ # hidden state becomes the input to the next layer
+ input = F.dropout(hidden, p=self.dropout_out, training=self.training)
+
+ # save state for next time step
+ prev_hiddens[i] = hidden
+ prev_cells[i] = cell
+
+ out = hidden
+ out = F.dropout(out, p=self.dropout_out, training=self.training)
+
+ # input feeding
+ input_feed = out
+
+ # save final output
+ outs.append(out)
+
+ # cache previous states (no-op except during incremental generation)
+ utils.set_incremental_state(
+ self,
+ incremental_state,
+ "cached_state",
+ (prev_hiddens, prev_cells, input_feed),
+ )
+
+ # collect outputs across time steps
+ x = torch.cat(outs, dim=0).view(seqlen, bsz, self.hidden_size)
+
+ # T x B x C -> B x T x C
+ x = x.transpose(1, 0)
+
+ # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
+ attn_scores = attn_scores.transpose(0, 2)
+
+ # project back to size of vocabulary
+ if hasattr(self, "additional_fc"):
+ x = self.additional_fc(x)
+ x = F.dropout(x, p=self.dropout_out, training=self.training)
+ x = self.fc_out(x)
+
+ return x, attn_scores
+
+ def reorder_incremental_state(self, incremental_state, new_order):
+ super().reorder_incremental_state(incremental_state, new_order)
+ cached_state = utils.get_incremental_state(
+ self, incremental_state, "cached_state"
+ )
+ if cached_state is None:
+ return
+
+ def reorder_state(state):
+ if isinstance(state, list):
+ return [reorder_state(state_i) for state_i in state]
+ return state.index_select(0, new_order)
+
+ new_state = tuple(map(reorder_state, cached_state))
+ utils.set_incremental_state(self, incremental_state, "cached_state", new_state)
+
+ def max_positions(self):
+ """Maximum output length supported by the decoder."""
+ return int(1e5) # an arbitrary large number
+
+
+def Embedding(num_embeddings, embedding_dim, padding_idx):
+ m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
+ nn.init.uniform_(m.weight, -0.1, 0.1)
+ nn.init.constant_(m.weight[padding_idx], 0)
+ return m
+
+
+def LSTM(input_size, hidden_size, **kwargs):
+ m = nn.LSTM(input_size, hidden_size, **kwargs)
+ for name, param in m.named_parameters():
+ if "weight" in name or "bias" in name:
+ param.data.uniform_(-0.1, 0.1)
+ return m
+
+
+def LSTMCell(input_size, hidden_size, **kwargs):
+ m = nn.LSTMCell(input_size, hidden_size, **kwargs)
+ for name, param in m.named_parameters():
+ if "weight" in name or "bias" in name:
+ param.data.uniform_(-0.1, 0.1)
+ return m
+
+
+def Linear(in_features, out_features, bias=True, dropout=0):
+ """Weight-normalized Linear layer (input: N x T x C)"""
+ m = nn.Linear(in_features, out_features, bias=bias)
+ m.weight.data.uniform_(-0.1, 0.1)
+ if bias:
+ m.bias.data.uniform_(-0.1, 0.1)
+ return m
+
+
+@register_model_architecture("laser_lstm", "laser_lstm")
+def base_architecture(args):
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
+ args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
+ args.encoder_hidden_size = getattr(
+ args, "encoder_hidden_size", args.encoder_embed_dim
+ )
+ args.encoder_layers = getattr(args, "encoder_layers", 1)
+ args.encoder_bidirectional = getattr(args, "encoder_bidirectional", False)
+ args.encoder_dropout_in = getattr(args, "encoder_dropout_in", args.dropout)
+ args.encoder_dropout_out = getattr(args, "encoder_dropout_out", args.dropout)
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
+ args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
+ args.decoder_hidden_size = getattr(
+ args, "decoder_hidden_size", args.decoder_embed_dim
+ )
+ args.decoder_layers = getattr(args, "decoder_layers", 1)
+ args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 512)
+ args.decoder_dropout_in = getattr(args, "decoder_dropout_in", args.dropout)
+ args.decoder_dropout_out = getattr(args, "decoder_dropout_out", args.dropout)
+ args.decoder_zero_init = getattr(args, "decoder_zero_init", "0")
+ args.decoder_lang_embed_dim = getattr(args, "decoder_lang_embed_dim", 0)
+ args.fixed_embeddings = getattr(args, "fixed_embeddings", False)
diff --git a/examples/laser/laser_src/laser_task.py b/examples/laser/laser_src/laser_task.py
new file mode 100644
index 0000000000..e4152fde68
--- /dev/null
+++ b/examples/laser/laser_src/laser_task.py
@@ -0,0 +1,331 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from collections import OrderedDict, defaultdict
+import json
+import os
+import logging
+from argparse import ArgumentError
+
+from fairseq import options, models
+from fairseq.data import (
+ data_utils,
+ Dictionary,
+ LanguagePairDataset,
+ IndexedDataset,
+ FairseqDataset,
+)
+from .multitask_data_utils import (
+ MultitaskDatasetWrapper,
+ MultidatasetEpochBatchIterator,
+)
+
+
+from fairseq.tasks import LegacyFairseqTask, register_task
+
+logger = logging.getLogger(__name__)
+
+
+@register_task("laser")
+class LaserTask(LegacyFairseqTask):
+ @staticmethod
+ def add_args(parser):
+ """Add task-specific arguments to the parser."""
+ parser.add_argument(
+ "configfile", metavar="PATH", help="dataset configuration file in json"
+ )
+ parser.add_argument(
+ "--weighting-alpha",
+ type=float,
+ default=None,
+ help="alpha for automatic weighting",
+ )
+ parser.add_argument(
+ "--raw-text", action="store_true", help="load raw text dataset"
+ )
+ parser.add_argument(
+ "--left-pad-source",
+ default="True",
+ type=str,
+ metavar="BOOL",
+ help="pad the source on the left (default: True)",
+ )
+ parser.add_argument(
+ "--left-pad-target",
+ default="False",
+ type=str,
+ metavar="BOOL",
+ help="pad the target on the left (default: False)",
+ )
+ try:
+ parser.add_argument(
+ "--max-source-positions",
+ default=1024,
+ type=int,
+ metavar="N",
+ help="max number of tokens in the source sequence",
+ )
+ parser.add_argument(
+ "--max-target-positions",
+ default=1024,
+ type=int,
+ metavar="N",
+ help="max number of tokens in the target sequence",
+ )
+ except ArgumentError:
+ # this might have already been defined. Once we transition this to hydra it should be fine to add it here.
+ pass
+
+ def __init__(self, args, config, src_dictionary, tgt_dictionary, num_tasks):
+ super().__init__(args)
+ self.config = config
+ self.src_dictionary = src_dictionary
+ self.tgt_dictionary = tgt_dictionary
+ self.num_tasks = num_tasks
+
+ @classmethod
+ def setup_task(cls, args, **kwargs):
+ with open(args.configfile, "r") as f:
+ config = json.load(f)
+ num_tasks = max(dataset["id"] for dataset in config["train"]) + 1
+
+ args.left_pad_source = options.eval_bool(args.left_pad_source)
+ args.left_pad_target = options.eval_bool(args.left_pad_target)
+
+ src_dictionary = Dictionary.load(config["src_vocab"])
+ tgt_dictionary = Dictionary.load(config["tgt_vocab"])
+
+ logger.info(
+ "| src Dictionary {} : {} types".format(
+ config["src_vocab"], len(src_dictionary)
+ )
+ )
+ logger.info(
+ "| tgt Dictionary {} : {} types".format(
+ config["tgt_vocab"], len(tgt_dictionary)
+ )
+ )
+
+ return cls(args, config, src_dictionary, tgt_dictionary, num_tasks)
+
+ # Experimental overriding for backtranslation
+ def build_model(self, args):
+ model = models.build_model(args, self)
+ return model
+
+ def dataset(self, split):
+ if split not in self.datasets:
+ raise KeyError("Dataset not loaded: " + split)
+ return self.datasets[split]
+
+ def load_dataset(self, split, epoch=1, **kwargs):
+ """Load a dataset split."""
+
+ def indexed_dataset(path, dictionary):
+ if self.args.raw_text:
+ raise Exception("Unable to handle raw text.")
+ dataset = IndexedDataset(path, fix_lua_indexing=True)
+
+ return dataset
+
+ pair_datasets = OrderedDict()
+
+ if split == "valid":
+ self.datasets[split] = pair_datasets
+ return
+
+ if split not in self.config:
+ raise FileNotFoundError(
+ "Dataset not found in config file: {}".format(split)
+ )
+
+ size_by_corpus = defaultdict(int)
+ size_sum = 0
+ size_sum_with_subsampling = 0
+ init_pair_datasets = {}
+
+ for dataset_config in self.config[split]:
+ src_path = os.path.dirname(dataset_config["src"])
+ corpus_name = src_path.split("/")[-2]
+ language_pair_name = src_path.split("/")[-1]
+ pair_datasets_key = corpus_name + "-" + language_pair_name
+
+ logger.info(f"loading... {pair_datasets_key}")
+ if "src" in dataset_config:
+ src_dataset = indexed_dataset(
+ dataset_config["src"], self.src_dictionary
+ )
+ else:
+ src_dataset = None
+
+ if "tgt" in dataset_config:
+ tgt_dataset = indexed_dataset(
+ dataset_config["tgt"], self.tgt_dictionary
+ )
+ else:
+ tgt_dataset = None
+
+ dataset = LanguagePairDataset(
+ src_dataset,
+ src_dataset.sizes,
+ self.src_dictionary,
+ tgt_dataset,
+ tgt_dataset.sizes,
+ self.tgt_dictionary,
+ left_pad_source=self.args.left_pad_source,
+ left_pad_target=self.args.left_pad_target,
+ )
+
+ if pair_datasets_key in init_pair_datasets:
+ logger.warning(
+ f"Ignoring already added {pair_datasets_key}. "
+ f"Consider using `sample` key in order to upsample."
+ )
+ else:
+ init_pair_datasets[pair_datasets_key] = {
+ "dataset": dataset,
+ "sample": dataset_config.get("sample", None),
+ "id": dataset_config.get("id", None),
+ "len": len(dataset),
+ }
+
+ length_sum = 0
+ weighted_freqs_sum = 0
+ freq_per_dataset = {}
+ vmax = 0
+ vmin = 1
+ weighted_freq_per_dataset = {}
+
+ if self.args.weighting_alpha:
+ for key in init_pair_datasets:
+ if init_pair_datasets[key]["sample"] is None:
+ length_sum += len(init_pair_datasets[key]["dataset"])
+
+ for key in init_pair_datasets:
+ if init_pair_datasets[key]["sample"] is None:
+ val = float(init_pair_datasets[key]["len"]) / length_sum
+ freq_per_dataset[key] = val
+ weighted_freqs_sum += val ** self.args.weighting_alpha
+
+ for key in freq_per_dataset:
+ val = (
+ freq_per_dataset[key] ** self.args.weighting_alpha
+ / weighted_freqs_sum
+ )
+ vmin = min(vmin, val)
+ vmax = max(vmax, val)
+ weighted_freq_per_dataset[key] = val
+
+ for pair_datasets_key in init_pair_datasets:
+ dataset_config = init_pair_datasets[pair_datasets_key]
+ dataset = dataset_config["dataset"]
+ sample = dataset_config["sample"]
+ if sample is None:
+ sample = 1.0
+
+ if pair_datasets_key in weighted_freq_per_dataset:
+ w = vmax / weighted_freq_per_dataset[pair_datasets_key]
+ sample = w
+
+ sample = round(sample)
+
+ initial_sample = sample
+ initial_pair_datasets_key = pair_datasets_key
+
+ while sample >= 1.0:
+ assert (
+ pair_datasets_key not in pair_datasets
+ ), f"{pair_datasets_key} already in"
+ size_sum_with_subsampling += len(dataset)
+ pair_datasets[pair_datasets_key] = MultitaskDatasetWrapper(
+ dataset, dataset_config.get("id", 0), 1.0, name=pair_datasets_key
+ )
+ size_sum += len(dataset)
+ sample -= 1.0
+ pair_datasets_key += "-up"
+
+ assert sample < 1e-6, f"sample remains > 0 {pair_datasets_key}"
+
+ logger.info(
+ f"added pair {initial_pair_datasets_key} length {len(dataset)} new_length = {len(dataset)*initial_sample}"
+ )
+ size_by_corpus[corpus_name] += len(dataset)
+
+ self.datasets[split] = pair_datasets
+ logger.info(
+ f"Datasets number = {len(self.datasets[split])} size = {size_sum} size_sum_with_subsampling = {size_sum_with_subsampling}"
+ )
+
+ @property
+ def source_dictionary(self):
+ return self.src_dictionary
+
+ @property
+ def target_dictionary(self):
+ return self.tgt_dictionary
+
+ def get_batch_iterator(
+ self,
+ dataset,
+ max_tokens=None,
+ max_sentences=None,
+ max_positions=None,
+ ignore_invalid_inputs=False,
+ required_batch_size_multiple=1,
+ seed=1,
+ num_shards=1,
+ shard_id=0,
+ num_workers=0,
+ epoch=1,
+ data_buffer_size=0,
+ disable_iterator_cache=False,
+ ):
+
+ assert isinstance(dataset, OrderedDict)
+ assert len(dataset)
+ assert isinstance(dataset[next(iter(dataset))], FairseqDataset)
+
+ # initialize the dataset with the correct starting epoch
+ for _, dt in dataset.items():
+ dt.set_epoch(epoch)
+
+ indices = OrderedDict()
+ batch_sampler = OrderedDict()
+
+ with data_utils.numpy_seed(seed + epoch):
+ for key, dt in dataset.items():
+ logger.info(f"\t ordered_indices {key}")
+ indices[key] = dt.ordered_indices()
+
+ # filter examples that are too large
+ if max_positions is not None:
+ for key, dt in dataset.items():
+ logger.info(f"\t filter_by_size {key}")
+ indices[key], ignored = dt.filter_indices_by_size(
+ indices[key], max_positions
+ )
+
+ for key, dt in dataset.items():
+ logger.info(f"\t batch_by_size {key}")
+ batch_sampler[key] = data_utils.batch_by_size(
+ indices[key],
+ dt.num_tokens,
+ max_tokens=max_tokens,
+ max_sentences=max_sentences,
+ required_batch_size_multiple=required_batch_size_multiple,
+ )
+
+ epoch_iter = MultidatasetEpochBatchIterator(
+ dataset=dataset,
+ batch_sampler=batch_sampler,
+ seed=seed,
+ num_shards=num_shards,
+ shard_id=shard_id,
+ num_workers=num_workers,
+ epoch=epoch,
+ )
+
+ return epoch_iter
diff --git a/examples/laser/laser_src/laser_transformer.py b/examples/laser/laser_src/laser_transformer.py
new file mode 100644
index 0000000000..0be030994f
--- /dev/null
+++ b/examples/laser/laser_src/laser_transformer.py
@@ -0,0 +1,354 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+
+from typing import Any, Dict, List, Optional
+from torch import Tensor
+
+import torch
+import torch.nn as nn
+
+from fairseq.models import (
+ FairseqEncoderDecoderModel,
+ register_model,
+ register_model_architecture,
+)
+from fairseq.models.transformer import (
+ base_architecture,
+ Embedding,
+ TransformerModel,
+ TransformerEncoder,
+ TransformerDecoder,
+)
+from fairseq.modules import (
+ TransformerDecoderLayer,
+)
+
+logger = logging.getLogger(__name__)
+
+
+@register_model("laser_transformer")
+class LaserTransformerModel(FairseqEncoderDecoderModel):
+ """Train Transformer for LASER task
+
+ Requires --task laser
+ """
+
+ def __init__(self, encoder, decoder):
+ super().__init__(encoder, decoder)
+
+ def forward(
+ self,
+ src_tokens,
+ src_lengths,
+ prev_output_tokens=None,
+ tgt_tokens=None,
+ tgt_lengths=None,
+ target_language_id=-1,
+ dataset_name="",
+ ):
+ laser_encoder_out = self.encoder(src_tokens, src_lengths)
+ return self.decoder(
+ prev_output_tokens, laser_encoder_out, lang_id=target_language_id
+ )
+
+ @staticmethod
+ def add_args(parser):
+ """Add model-specific arguments to the parser."""
+ TransformerModel.add_args(parser)
+ parser.add_argument(
+ "--decoder-lang-embed-dim",
+ type=int,
+ metavar="N",
+ help="decoder language embedding dimension",
+ )
+
+ @classmethod
+ def build_model(cls, args, task):
+ base_laser_transformer_architecture(args)
+
+ num_langs = task.num_tasks if hasattr(task, "num_tasks") else 0
+
+ def load_embed_tokens(dictionary, embed_dim):
+ num_embeddings = len(dictionary)
+ padding_idx = dictionary.pad()
+
+ return Embedding(num_embeddings, embed_dim, padding_idx)
+
+ encoder_embed_tokens = load_embed_tokens(
+ task.source_dictionary, args.encoder_embed_dim
+ )
+ decoder_embed_tokens = load_embed_tokens(
+ task.target_dictionary, args.decoder_embed_dim
+ )
+ num_langs = task.num_tasks if hasattr(task, "num_tasks") else 0
+
+ encoder = LaserTransformerEncoder(
+ args, task.source_dictionary, encoder_embed_tokens
+ )
+
+ decoder = LaserTransformerDecoder(
+ args,
+ task.target_dictionary,
+ decoder_embed_tokens,
+ num_langs=num_langs,
+ lang_embed_dim=args.decoder_lang_embed_dim,
+ )
+
+ return cls(encoder, decoder)
+
+
+class LaserTransformerEncoder(TransformerEncoder):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(self, src_tokens, *args, **kwargs):
+ encoder_out = super().forward(src_tokens, *args, **kwargs)
+
+ x = encoder_out["encoder_out"][0] # T x B x C
+ padding_mask = src_tokens.eq(self.padding_idx).t().unsqueeze(-1)
+
+ if padding_mask.any():
+ x = x.float().masked_fill_(padding_mask, float("-inf")).type_as(x)
+
+ # Build the sentence embedding by max-pooling over the encoder outputs
+ sentemb = x.max(dim=0)[0]
+
+ # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in
+ # `foward` so we use a dictionary instead.
+ # TorchScript does not support mixed values so the values are all lists.
+ # The empty list is equivalent to None.
+ return {"sentemb": [sentemb]} # B x C
+
+ @torch.jit.export
+ def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order):
+ """
+ Same as the one in transformer.py, with new_sentemb
+ """
+ if len(encoder_out["sentemb"]) == 0:
+ new_sentemb = []
+ else:
+ new_sentemb = [encoder_out["sentemb"][0].index_select(0, new_order)]
+
+ return {
+ "sentemb": new_sentemb, # B x C
+ }
+
+
+class LaserTransformerDecoder(TransformerDecoder):
+ def __init__(self, args, dictionary, *kargs, **kwargs):
+ self.num_langs = kwargs.get("num_langs", 1)
+ self.lang_embed_dim = kwargs.get("lang_embed_dim", 0)
+ kwargs.pop("num_langs", None)
+ kwargs.pop("lang_embed_dim", None)
+
+ super().__init__(args, dictionary, *kargs, **kwargs, no_encoder_attn=True)
+
+ if self.lang_embed_dim == 0:
+ self.embed_lang = None
+ else:
+ self.embed_lang = nn.Embedding(self.num_langs, self.lang_embed_dim)
+ nn.init.uniform_(self.embed_lang.weight, -0.1, 0.1)
+
+ if self.output_projection is not None:
+ laser_output_embed_dim = (
+ self.output_embed_dim + self.lang_embed_dim + args.encoder_embed_dim
+ )
+ self.output_projection = nn.Linear(
+ laser_output_embed_dim, len(dictionary), bias=False
+ )
+ nn.init.normal_(
+ self.output_projection.weight,
+ mean=0,
+ std=laser_output_embed_dim ** -0.5,
+ )
+
+ def build_decoder_layer(self, args, no_encoder_attn=False):
+ decoder_embed_dim = args.decoder_embed_dim
+ args.decoder_embed_dim = (
+ decoder_embed_dim + self.lang_embed_dim + args.encoder_embed_dim
+ )
+ res = TransformerDecoderLayer(args, no_encoder_attn=True)
+ args.decoder_embed_dim = decoder_embed_dim
+
+ return res
+
+ def extract_features(
+ self,
+ prev_output_tokens,
+ encoder_out: Optional[Dict[str, List[Tensor]]],
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
+ full_context_alignment: bool = False,
+ alignment_layer: Optional[int] = None,
+ alignment_heads: Optional[int] = None,
+ lang_id: Optional[int] = None,
+ ):
+ """
+ Similar to *forward* but only return features.
+
+ Includes several features from "Jointly Learning to Align and
+ Translate with Transformer Models" (Garg et al., EMNLP 2019).
+
+ Args:
+ full_context_alignment (bool, optional): don't apply
+ auto-regressive mask to self-attention (default: False).
+ alignment_layer (int, optional): return mean alignment over
+ heads at this layer (default: last layer).
+ alignment_heads (int, optional): only average alignment over
+ this many heads (default: all heads).
+
+ Returns:
+ tuple:
+ - the decoder's features of shape `(batch, tgt_len, embed_dim)`
+ - a dictionary with any model-specific outputs
+ """
+ if alignment_layer is None:
+ alignment_layer = self.num_layers - 1
+
+ # embed positions
+ positions = (
+ self.embed_positions(
+ prev_output_tokens, incremental_state=incremental_state
+ )
+ if self.embed_positions is not None
+ else None
+ )
+
+ if incremental_state is not None:
+ prev_output_tokens = prev_output_tokens[:, -1:]
+ if positions is not None:
+ positions = positions[:, -1:]
+
+ bsz, seqlen = prev_output_tokens.size()
+
+ # embed tokens and positions
+ x = self.embed_scale * self.embed_tokens(prev_output_tokens)
+
+ if self.quant_noise is not None:
+ x = self.quant_noise(x)
+
+ if self.project_in_dim is not None:
+ x = self.project_in_dim(x)
+
+ if positions is not None:
+ x += positions
+
+ if self.layernorm_embedding is not None:
+ x = self.layernorm_embedding(x)
+
+ x = self.dropout_module(x)
+
+ # B x T x C -> T x B x C
+ x = x.transpose(0, 1)
+
+ if self.embed_lang is not None:
+ lang_ids = prev_output_tokens.data.new_full((bsz,), lang_id)
+ langemb = self.embed_lang(lang_ids)
+ langemb = langemb.unsqueeze(0)
+ repeat_vals = [x.shape[0] // langemb.shape[0]] + [-1] * (
+ len(langemb.shape) - 1
+ )
+ x = torch.cat((x, langemb.expand(*repeat_vals)), dim=-1)
+
+ sentemb = encoder_out["sentemb"][0]
+ sentemb = sentemb.unsqueeze(0)
+
+ repeat_vals = [x.shape[0] // sentemb.shape[0]] + [-1] * (len(sentemb.shape) - 1)
+ x = torch.cat((x, sentemb.expand(*repeat_vals)), dim=-1)
+
+ self_attn_padding_mask: Optional[Tensor] = None
+ if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
+ self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
+
+ # decoder layers
+ attn: Optional[Tensor] = None
+ inner_states: List[Optional[Tensor]] = [x]
+ for idx, layer in enumerate(self.layers):
+ if incremental_state is None and not full_context_alignment:
+ self_attn_mask = self.buffered_future_mask(x)
+ else:
+ self_attn_mask = None
+
+ x, layer_attn, _ = layer(
+ x,
+ None,
+ None,
+ incremental_state,
+ self_attn_mask=self_attn_mask,
+ self_attn_padding_mask=self_attn_padding_mask,
+ need_attn=bool((idx == alignment_layer)),
+ need_head_weights=bool((idx == alignment_layer)),
+ )
+ inner_states.append(x)
+ if layer_attn is not None and idx == alignment_layer:
+ attn = layer_attn.float().to(x)
+
+ if attn is not None:
+ if alignment_heads is not None:
+ attn = attn[:alignment_heads]
+
+ # average probabilities over heads
+ attn = attn.mean(dim=0)
+
+ if self.layer_norm is not None:
+ x = self.layer_norm(x)
+
+ # T x B x C -> B x T x C
+ x = x.transpose(0, 1)
+
+ if self.project_out_dim is not None:
+ x = self.project_out_dim(x)
+
+ return x, {"attn": [attn], "inner_states": inner_states}
+
+ def forward(
+ self,
+ prev_output_tokens,
+ encoder_out: Optional[Dict[str, List[Tensor]]] = None,
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
+ features_only: bool = False,
+ alignment_layer: Optional[int] = None,
+ alignment_heads: Optional[int] = None,
+ src_lengths: Optional[Any] = None,
+ return_all_hiddens: bool = False,
+ lang_id: Optional[int] = None,
+ ):
+ """
+ Args:
+ prev_output_tokens (LongTensor): previous decoder outputs of shape
+ `(batch, tgt_len)`, for teacher forcing
+ encoder_out (optional): output from the encoder, used for
+ encoder-side attention
+ incremental_state (dict): dictionary used for storing state during
+ :ref:`Incremental decoding`
+ features_only (bool, optional): only return features without
+ applying output layer (default: False).
+
+ Returns:
+ tuple:
+ - the decoder's output of shape `(batch, tgt_len, vocab)`
+ - a dictionary with any model-specific outputs
+ """
+
+ assert lang_id is not None
+
+ x, extra = self.extract_features(
+ prev_output_tokens,
+ encoder_out=encoder_out,
+ incremental_state=incremental_state,
+ alignment_layer=alignment_layer,
+ alignment_heads=alignment_heads,
+ lang_id=lang_id,
+ )
+ if not features_only:
+ x = self.output_layer(x)
+ return x, extra
+
+
+@register_model_architecture("laser_transformer", "laser_transformer")
+def base_laser_transformer_architecture(args):
+ base_architecture(args)
+ args.decoder_lang_embed_dim = getattr(args, "decoder_lang_embed_dim", 0)
diff --git a/examples/laser/laser_src/multitask_data_utils.py b/examples/laser/laser_src/multitask_data_utils.py
new file mode 100644
index 0000000000..b05caea267
--- /dev/null
+++ b/examples/laser/laser_src/multitask_data_utils.py
@@ -0,0 +1,143 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from collections import OrderedDict
+
+import numpy as np
+
+from fairseq.data import BaseWrapperDataset, FairseqDataset, iterators
+
+
+class MultiItr(object):
+ def __init__(self, itr):
+ self.itr = itr
+ self._counts = [0 for x in itr]
+
+ def __len__(self):
+ return sum(len(itr) for itr in self.itr)
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ ratios = [count / len(itr) for count, itr in zip(self._counts, self.itr)]
+ idx = ratios.index(min(ratios))
+ self._counts[idx] += 1
+ return next(self.itr[idx])
+
+
+class MultidatasetEpochBatchIterator(iterators.EpochBatchIterating):
+ """A wrapper around multiple epoch batch iterators."""
+
+ def __init__(
+ self,
+ dataset,
+ batch_sampler,
+ seed=1,
+ num_shards=1,
+ shard_id=0,
+ num_workers=0,
+ epoch=1,
+ ):
+
+ assert isinstance(dataset, OrderedDict)
+ assert len(dataset)
+ assert isinstance(dataset[next(iter(dataset))], FairseqDataset)
+
+ self.iterators = []
+
+ self.epoch = epoch
+ for key, dt in dataset.items():
+ epoch_iter = iterators.EpochBatchIterator(
+ dataset=dt,
+ collate_fn=dt.collater,
+ batch_sampler=batch_sampler[key],
+ seed=seed,
+ num_shards=num_shards,
+ shard_id=shard_id,
+ num_workers=0,
+ epoch=epoch,
+ )
+ self.iterators.append(epoch_iter)
+
+ def __len__(self):
+ return sum(len(itr) for itr in self.iterators)
+
+ def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False):
+ # `self.epoch += 1` should be handled by underlying `EpochBatchIterator`s.
+ return MultiItr(
+ [
+ itr.next_epoch_itr(
+ shuffle=shuffle, fix_batches_to_gpus=fix_batches_to_gpus
+ )
+ for itr in self.iterators
+ ]
+ )
+
+ def end_of_epoch(self):
+ return all(itr.end_of_epoch() for itr in self.iterators)
+
+ @property
+ def next_epoch_idx(self):
+ """Return the epoch index after *next_epoch_itr* is called."""
+
+ epochs = [itr.next_epoch_idx for itr in self.iterators]
+ self.epoch = epochs[0]
+ assert all(epoch == self.epoch for epoch in epochs)
+
+ return self.epoch
+
+ @property
+ def iterations_in_epoch(self):
+ return sum(itr.iterations_in_epoch for itr in self.iterators)
+
+ def state_dict(self):
+ return {
+ "iterators": [it.state_dict() for it in self.iterators],
+ "epoch": self.epoch,
+ }
+
+ def load_state_dict(self, state_dict):
+ self.epoch = state_dict["epoch"]
+ for it, d in zip(self.iterators, state_dict["iterators"]):
+ it.load_state_dict(d)
+
+
+class MultitaskDatasetWrapper(BaseWrapperDataset):
+ """A wrapper for a multitask dataset."""
+
+ def __init__(self, dataset, target_language_id, sample=1.0, name=""):
+ super().__init__(dataset)
+ self.target_language_id = target_language_id
+ self.sample = sample
+ self.name = name
+
+ def collater(self, *args, **kwargs):
+ ans = self.dataset.collater(*args, **kwargs)
+ if "net_input" in ans:
+ ans["net_input"]["target_language_id"] = self.target_language_id
+ ans["net_input"]["dataset_name"] = self.name
+ return ans
+
+ def num_tokens(self, *args, **kwargs):
+ return self.dataset.num_tokens(*args, **kwargs)
+
+ def ordered_indices(self, *args, **kwargs):
+ indices = self.dataset.ordered_indices(*args, **kwargs)
+ # Hacky solution for sampling
+ size = int(self.sample * indices.shape[0])
+
+ return indices.take(np.sort(np.random.permutation(indices.shape[0])[:size]))
+
+ def size(self, index: int):
+ return self.dataset.size(index)
+
+ @property
+ def supports_prefetch(self):
+ """Whether this dataset supports prefetching."""
+ return getattr(self.dataset, "supports_prefetch", False)
+
+ def prefetch(self, indices):
+ return self.dataset.prefetch(indices)
diff --git a/examples/latent_depth/README.md b/examples/latent_depth/README.md
new file mode 100644
index 0000000000..7774c33305
--- /dev/null
+++ b/examples/latent_depth/README.md
@@ -0,0 +1,77 @@
+# Deep Transformers with Latent Depth (Li et al., 2020)
+
+[https://arxiv.org/abs/2009.13102](https://arxiv.org/abs/2009.13102).
+
+## Introduction
+
+We present a probabilistic framework to automatically learn which layer(s) to use by learning the posterior distributions of layer selection. As an extension of this framework, we propose a novel method to train one shared Transformer network for multilingual machine translation with different layer selection posteriors for each language pair.
+
+## Training a multilingual model with latent depth
+
+Below is an example of training with latent depth in decoder for one-to-many (O2M) related languages. We use the same preprocessed (numberized and binarized) TED8 dataset as in [Balancing Training for Multilingual Neural Machine Translation (Wang et al., 2020)](https://github.com/cindyxinyiwang/multiDDS), which could be generated by [the script](https://github.com/cindyxinyiwang/multiDDS/blob/multiDDS/util_scripts/prepare_multilingual_data.sh) the author provided.
+```bash
+lang_pairs_str="eng-aze,eng-bel,eng-ces,eng-glg,eng-por,eng-rus,eng-slk,eng-tur"
+databin_dir= Q: Where would I not want a fox? A: hen house `
- question = 'Q: ' + question
+ question = "Q: " + question
question_toks = binarize(question, append_bos=True)
- for i, choice in enumerate(example['question']['choices']):
- src = 'A: ' + choice['text']
+ for i, choice in enumerate(example["question"]["choices"]):
+ src = "A: " + choice["text"]
src_bin = torch.cat([question_toks, binarize(src)])
src_tokens[i].append(src_bin)
src_lengths[i].append(len(src_bin))
- assert all(len(src_tokens[0]) == len(src_tokens[i]) for i in range(self.args.num_classes))
+ assert all(
+ len(src_tokens[0]) == len(src_tokens[i])
+ for i in range(self.args.num_classes)
+ )
assert len(src_tokens[0]) == len(src_lengths[0])
assert len(labels) == 0 or len(labels) == len(src_tokens[0])
@@ -118,24 +131,26 @@ def binarize(s, append_bos=False):
src_lengths[i] = ListDataset(src_lengths[i])
dataset = {
- 'id': IdDataset(),
- 'nsentences': NumSamplesDataset(),
- 'ntokens': NumelDataset(src_tokens[0], reduce=True),
+ "id": IdDataset(),
+ "nsentences": NumSamplesDataset(),
+ "ntokens": NumelDataset(src_tokens[0], reduce=True),
}
for i in range(self.args.num_classes):
- dataset.update({
- 'net_input{}'.format(i + 1): {
- 'src_tokens': RightPadDataset(
- src_tokens[i],
- pad_idx=self.source_dictionary.pad(),
- ),
- 'src_lengths': src_lengths[i],
+ dataset.update(
+ {
+ "net_input{}".format(i + 1): {
+ "src_tokens": RightPadDataset(
+ src_tokens[i],
+ pad_idx=self.source_dictionary.pad(),
+ ),
+ "src_lengths": src_lengths[i],
+ }
}
- })
+ )
if len(labels) > 0:
- dataset.update({'target': RawLabelDataset(labels)})
+ dataset.update({"target": RawLabelDataset(labels)})
dataset = NestedDictionaryDataset(
dataset,
@@ -149,17 +164,18 @@ def binarize(s, append_bos=False):
sort_order=[np.random.permutation(len(dataset))],
)
- print('| Loaded {} with {} samples'.format(split, len(dataset)))
+ print("| Loaded {} with {} samples".format(split, len(dataset)))
self.datasets[split] = dataset
return self.datasets[split]
def build_model(self, args):
from fairseq import models
+
model = models.build_model(args, self)
model.register_classification_head(
- 'sentence_classification_head',
+ "sentence_classification_head",
num_classes=1,
)
diff --git a/examples/roberta/config/finetuning/cola.yaml b/examples/roberta/config/finetuning/cola.yaml
new file mode 100644
index 0000000000..ac76611201
--- /dev/null
+++ b/examples/roberta/config/finetuning/cola.yaml
@@ -0,0 +1,59 @@
+# @package _group_
+
+common:
+ fp16: true
+ fp16_init_scale: 4
+ threshold_loss_scale: 1
+ fp16_scale_window: 128
+ log_format: json
+ log_interval: 200
+
+task:
+ _name: sentence_prediction
+ data: ???
+ init_token: 0
+ separator_token: 2
+ num_classes: 2
+ max_positions: 512
+
+checkpoint:
+ restore_file: ???
+ reset_optimizer: true
+ reset_dataloader: true
+ reset_meters: true
+ best_checkpoint_metric: accuracy
+ maximize_best_checkpoint_metric: true
+ no_epoch_checkpoints: true
+
+distributed_training:
+ find_unused_parameters: true
+ distributed_world_size: 1
+
+criterion:
+ _name: sentence_prediction
+
+dataset:
+ batch_size: 16
+ required_batch_size_multiple: 1
+ max_tokens: 4400
+
+optimizer:
+ _name: adam
+ weight_decay: 0.1
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+
+lr_scheduler:
+ _name: polynomial_decay
+ warmup_updates: 320
+
+optimization:
+ clip_norm: 0.0
+ lr: [1e-05]
+ max_update: 5336
+ max_epoch: 10
+
+model:
+ _name: roberta
+ dropout: 0.1
+ attention_dropout: 0.1
diff --git a/examples/roberta/config/finetuning/mnli.yaml b/examples/roberta/config/finetuning/mnli.yaml
new file mode 100644
index 0000000000..5be10c362f
--- /dev/null
+++ b/examples/roberta/config/finetuning/mnli.yaml
@@ -0,0 +1,59 @@
+# @package _group_
+
+common:
+ fp16: true
+ fp16_init_scale: 4
+ threshold_loss_scale: 1
+ fp16_scale_window: 128
+ log_format: json
+ log_interval: 200
+
+task:
+ _name: sentence_prediction
+ data: ???
+ init_token: 0
+ separator_token: 2
+ num_classes: 3
+ max_positions: 512
+
+checkpoint:
+ restore_file: ???
+ reset_optimizer: true
+ reset_dataloader: true
+ reset_meters: true
+ best_checkpoint_metric: accuracy
+ maximize_best_checkpoint_metric: true
+ no_epoch_checkpoints: true
+
+distributed_training:
+ find_unused_parameters: true
+ distributed_world_size: 1
+
+criterion:
+ _name: sentence_prediction
+
+dataset:
+ batch_size: 32
+ required_batch_size_multiple: 1
+ max_tokens: 4400
+
+optimizer:
+ _name: adam
+ weight_decay: 0.1
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+
+lr_scheduler:
+ _name: polynomial_decay
+ warmup_updates: 7432
+
+optimization:
+ clip_norm: 0.0
+ lr: [1e-05]
+ max_update: 123873
+ max_epoch: 10
+
+model:
+ _name: roberta
+ dropout: 0.1
+ attention_dropout: 0.1
diff --git a/examples/roberta/config/finetuning/mrpc.yaml b/examples/roberta/config/finetuning/mrpc.yaml
new file mode 100644
index 0000000000..aa8b7db393
--- /dev/null
+++ b/examples/roberta/config/finetuning/mrpc.yaml
@@ -0,0 +1,59 @@
+# @package _group_
+
+common:
+ fp16: true
+ fp16_init_scale: 4
+ threshold_loss_scale: 1
+ fp16_scale_window: 128
+ log_format: json
+ log_interval: 200
+
+task:
+ _name: sentence_prediction
+ data: ???
+ init_token: 0
+ separator_token: 2
+ num_classes: 2
+ max_positions: 512
+
+checkpoint:
+ restore_file: ???
+ reset_optimizer: true
+ reset_dataloader: true
+ reset_meters: true
+ best_checkpoint_metric: accuracy
+ maximize_best_checkpoint_metric: true
+ no_epoch_checkpoints: true
+
+distributed_training:
+ find_unused_parameters: true
+ distributed_world_size: 1
+
+criterion:
+ _name: sentence_prediction
+
+dataset:
+ batch_size: 16
+ required_batch_size_multiple: 1
+ max_tokens: 4400
+
+optimizer:
+ _name: adam
+ weight_decay: 0.1
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+
+lr_scheduler:
+ _name: polynomial_decay
+ warmup_updates: 137
+
+optimization:
+ clip_norm: 0.0
+ lr: [1e-05]
+ max_update: 2296
+ max_epoch: 10
+
+model:
+ _name: roberta
+ dropout: 0.1
+ attention_dropout: 0.1
diff --git a/examples/roberta/config/finetuning/qnli.yaml b/examples/roberta/config/finetuning/qnli.yaml
new file mode 100644
index 0000000000..b4595b090e
--- /dev/null
+++ b/examples/roberta/config/finetuning/qnli.yaml
@@ -0,0 +1,59 @@
+# @package _group_
+
+common:
+ fp16: true
+ fp16_init_scale: 4
+ threshold_loss_scale: 1
+ fp16_scale_window: 128
+ log_format: json
+ log_interval: 200
+
+task:
+ _name: sentence_prediction
+ data: ???
+ init_token: 0
+ separator_token: 2
+ num_classes: 2
+ max_positions: 512
+
+checkpoint:
+ restore_file: ???
+ reset_optimizer: true
+ reset_dataloader: true
+ reset_meters: true
+ best_checkpoint_metric: accuracy
+ maximize_best_checkpoint_metric: true
+ no_epoch_checkpoints: true
+
+distributed_training:
+ find_unused_parameters: true
+ distributed_world_size: 1
+
+criterion:
+ _name: sentence_prediction
+
+dataset:
+ batch_size: 32
+ required_batch_size_multiple: 1
+ max_tokens: 4400
+
+optimizer:
+ _name: adam
+ weight_decay: 0.1
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+
+lr_scheduler:
+ _name: polynomial_decay
+ warmup_updates: 1986
+
+optimization:
+ clip_norm: 0.0
+ lr: [1e-05]
+ max_update: 33112
+ max_epoch: 10
+
+model:
+ _name: roberta
+ dropout: 0.1
+ attention_dropout: 0.1
diff --git a/examples/roberta/config/finetuning/qqp.yaml b/examples/roberta/config/finetuning/qqp.yaml
new file mode 100644
index 0000000000..5a2b2ed743
--- /dev/null
+++ b/examples/roberta/config/finetuning/qqp.yaml
@@ -0,0 +1,59 @@
+# @package _group_
+
+common:
+ fp16: true
+ fp16_init_scale: 4
+ threshold_loss_scale: 1
+ fp16_scale_window: 128
+ log_format: json
+ log_interval: 200
+
+task:
+ _name: sentence_prediction
+ data: ???
+ init_token: 0
+ separator_token: 2
+ num_classes: 2
+ max_positions: 512
+
+checkpoint:
+ restore_file: ???
+ reset_optimizer: true
+ reset_dataloader: true
+ reset_meters: true
+ best_checkpoint_metric: accuracy
+ maximize_best_checkpoint_metric: true
+ no_epoch_checkpoints: true
+
+distributed_training:
+ find_unused_parameters: true
+ distributed_world_size: 1
+
+criterion:
+ _name: sentence_prediction
+
+dataset:
+ batch_size: 32
+ required_batch_size_multiple: 1
+ max_tokens: 4400
+
+optimizer:
+ _name: adam
+ weight_decay: 0.1
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+
+lr_scheduler:
+ _name: polynomial_decay
+ warmup_updates: 28318
+
+optimization:
+ clip_norm: 0.0
+ lr: [1e-05]
+ max_update: 113272
+ max_epoch: 10
+
+model:
+ _name: roberta
+ dropout: 0.1
+ attention_dropout: 0.1
diff --git a/examples/roberta/config/finetuning/rte.yaml b/examples/roberta/config/finetuning/rte.yaml
new file mode 100644
index 0000000000..7318465011
--- /dev/null
+++ b/examples/roberta/config/finetuning/rte.yaml
@@ -0,0 +1,59 @@
+# @package _group_
+
+common:
+ fp16: true
+ fp16_init_scale: 4
+ threshold_loss_scale: 1
+ fp16_scale_window: 128
+ log_format: json
+ log_interval: 200
+
+task:
+ _name: sentence_prediction
+ data: ???
+ init_token: 0
+ separator_token: 2
+ num_classes: 2
+ max_positions: 512
+
+checkpoint:
+ restore_file: ???
+ reset_optimizer: true
+ reset_dataloader: true
+ reset_meters: true
+ best_checkpoint_metric: accuracy
+ maximize_best_checkpoint_metric: true
+ no_epoch_checkpoints: true
+
+distributed_training:
+ find_unused_parameters: true
+ distributed_world_size: 1
+
+criterion:
+ _name: sentence_prediction
+
+dataset:
+ batch_size: 16
+ required_batch_size_multiple: 1
+ max_tokens: 4400
+
+optimizer:
+ _name: adam
+ weight_decay: 0.1
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+
+lr_scheduler:
+ _name: polynomial_decay
+ warmup_updates: 122
+
+optimization:
+ clip_norm: 0.0
+ lr: [2e-05]
+ max_update: 2036
+ max_epoch: 10
+
+model:
+ _name: roberta
+ dropout: 0.1
+ attention_dropout: 0.1
diff --git a/examples/roberta/config/finetuning/sst_2.yaml b/examples/roberta/config/finetuning/sst_2.yaml
new file mode 100644
index 0000000000..a93ad2f22c
--- /dev/null
+++ b/examples/roberta/config/finetuning/sst_2.yaml
@@ -0,0 +1,59 @@
+# @package _group_
+
+common:
+ fp16: true
+ fp16_init_scale: 4
+ threshold_loss_scale: 1
+ fp16_scale_window: 128
+ log_format: json
+ log_interval: 200
+
+task:
+ _name: sentence_prediction
+ data: ???
+ init_token: 0
+ separator_token: 2
+ num_classes: 2
+ max_positions: 512
+
+checkpoint:
+ restore_file: ???
+ reset_optimizer: true
+ reset_dataloader: true
+ reset_meters: true
+ best_checkpoint_metric: accuracy
+ maximize_best_checkpoint_metric: true
+ no_epoch_checkpoints: true
+
+distributed_training:
+ find_unused_parameters: true
+ distributed_world_size: 1
+
+criterion:
+ _name: sentence_prediction
+
+dataset:
+ batch_size: 32
+ required_batch_size_multiple: 1
+ max_tokens: 4400
+
+optimizer:
+ _name: adam
+ weight_decay: 0.1
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+
+lr_scheduler:
+ _name: polynomial_decay
+ warmup_updates: 1256
+
+optimization:
+ clip_norm: 0.0
+ lr: [1e-05]
+ max_update: 20935
+ max_epoch: 10
+
+model:
+ _name: roberta
+ dropout: 0.1
+ attention_dropout: 0.1
diff --git a/examples/roberta/config/finetuning/sts_b.yaml b/examples/roberta/config/finetuning/sts_b.yaml
new file mode 100644
index 0000000000..2d495221ad
--- /dev/null
+++ b/examples/roberta/config/finetuning/sts_b.yaml
@@ -0,0 +1,58 @@
+# @package _group_
+
+common:
+ fp16: true
+ fp16_init_scale: 4
+ threshold_loss_scale: 1
+ fp16_scale_window: 128
+ log_format: json
+ log_interval: 200
+
+task:
+ _name: sentence_prediction
+ data: ???
+ init_token: 0
+ separator_token: 2
+ num_classes: 1
+ max_positions: 512
+
+checkpoint:
+ restore_file: ???
+ reset_optimizer: true
+ reset_dataloader: true
+ reset_meters: true
+ no_epoch_checkpoints: true
+
+distributed_training:
+ find_unused_parameters: true
+ distributed_world_size: 1
+
+criterion:
+ _name: sentence_prediction
+ regression_target: true
+
+dataset:
+ batch_size: 16
+ required_batch_size_multiple: 1
+ max_tokens: 4400
+
+optimizer:
+ _name: adam
+ weight_decay: 0.1
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+
+lr_scheduler:
+ _name: polynomial_decay
+ warmup_updates: 214
+
+optimization:
+ clip_norm: 0.0
+ lr: [2e-05]
+ max_update: 3598
+ max_epoch: 10
+
+model:
+ _name: roberta
+ dropout: 0.1
+ attention_dropout: 0.1
diff --git a/examples/roberta/config/pretraining/base.yaml b/examples/roberta/config/pretraining/base.yaml
new file mode 100644
index 0000000000..97829908f7
--- /dev/null
+++ b/examples/roberta/config/pretraining/base.yaml
@@ -0,0 +1,42 @@
+# @package _group_
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+
+checkpoint:
+ no_epoch_checkpoints: true
+
+task:
+ _name: masked_lm
+ data: ???
+ sample_break_mode: complete
+ tokens_per_sample: 512
+
+criterion: masked_lm
+
+dataset:
+ batch_size: 16
+ ignore_unused_valid_subsets: true
+
+optimizer:
+ _name: adam
+ weight_decay: 0.01
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+
+lr_scheduler:
+ _name: polynomial_decay
+ warmup_updates: 10000
+
+optimization:
+ clip_norm: 0
+ lr: [0.0005]
+ max_update: 125000
+ update_freq: [16]
+
+model:
+ _name: roberta
+ max_positions: 512
+ dropout: 0.1
+ attention_dropout: 0.1
diff --git a/examples/roberta/multiprocessing_bpe_encoder.py b/examples/roberta/multiprocessing_bpe_encoder.py
index f0240c210f..43fe0451bf 100644
--- a/examples/roberta/multiprocessing_bpe_encoder.py
+++ b/examples/roberta/multiprocessing_bpe_encoder.py
@@ -8,7 +8,6 @@
import argparse
import contextlib
import sys
-
from collections import Counter
from multiprocessing import Pool
@@ -26,23 +25,23 @@ def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--encoder-json",
- help='path to encoder.json',
+ help="path to encoder.json",
)
parser.add_argument(
"--vocab-bpe",
type=str,
- help='path to vocab.bpe',
+ help="path to vocab.bpe",
)
parser.add_argument(
"--inputs",
nargs="+",
- default=['-'],
+ default=["-"],
help="input files to filter/encode",
)
parser.add_argument(
"--outputs",
nargs="+",
- default=['-'],
+ default=["-"],
help="path to save encoded outputs",
)
parser.add_argument(
@@ -53,18 +52,21 @@ def main():
parser.add_argument("--workers", type=int, default=20)
args = parser.parse_args()
- assert len(args.inputs) == len(args.outputs), \
- "number of input and output paths should match"
+ assert len(args.inputs) == len(
+ args.outputs
+ ), "number of input and output paths should match"
with contextlib.ExitStack() as stack:
inputs = [
stack.enter_context(open(input, "r", encoding="utf-8"))
- if input != "-" else sys.stdin
+ if input != "-"
+ else sys.stdin
for input in args.inputs
]
outputs = [
stack.enter_context(open(output, "w", encoding="utf-8"))
- if output != "-" else sys.stdout
+ if output != "-"
+ else sys.stdout
for output in args.outputs
]
@@ -87,7 +89,6 @@ def main():
class MultiprocessingEncoder(object):
-
def __init__(self, args):
self.args = args
diff --git a/examples/roberta/preprocess_RACE.py b/examples/roberta/preprocess_RACE.py
index f6f606a389..cdd6607271 100644
--- a/examples/roberta/preprocess_RACE.py
+++ b/examples/roberta/preprocess_RACE.py
@@ -25,7 +25,7 @@ def get_examples(data_dir, set_type):
examples = []
levels = ["middle", "high"]
- set_type_c = set_type.split('-')
+ set_type_c = set_type.split("-")
if len(set_type_c) == 2:
levels = [set_type_c[1]]
set_type = set_type_c[0]
@@ -33,13 +33,13 @@ def get_examples(data_dir, set_type):
cur_dir = os.path.join(data_dir, set_type, level)
for filename in os.listdir(cur_dir):
cur_path = os.path.join(cur_dir, filename)
- with open(cur_path, 'r') as f:
+ with open(cur_path, "r") as f:
cur_data = json.load(f)
answers = cur_data["answers"]
options = cur_data["options"]
questions = cur_data["questions"]
context = cur_data["article"].replace("\n", " ")
- context = re.sub(r'\s+', ' ', context)
+ context = re.sub(r"\s+", " ", context)
for i in range(len(answers)):
label = ord(answers[i]) - ord("A")
qa_list = []
@@ -50,7 +50,7 @@ def get_examples(data_dir, set_type):
qa_cat = question.replace("_", option)
else:
qa_cat = " ".join([question, option])
- qa_cat = re.sub(r'\s+', ' ', qa_cat)
+ qa_cat = re.sub(r"\s+", " ", qa_cat)
qa_list.append(qa_cat)
examples.append(InputExample(context, qa_list, label))
@@ -64,11 +64,11 @@ def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--input-dir",
- help='input directory for downloaded RACE dataset',
+ help="input directory for downloaded RACE dataset",
)
parser.add_argument(
"--output-dir",
- help='output directory for extracted data',
+ help="output directory for extracted data",
)
args = parser.parse_args()
@@ -77,17 +77,20 @@ def main():
for set_type in ["train", "dev", "test-middle", "test-high"]:
examples = get_examples(args.input_dir, set_type)
- qa_file_paths = [os.path.join(args.output_dir, set_type + ".input" + str(i + 1)) for i in range(4)]
- qa_files = [open(qa_file_path, 'w') for qa_file_path in qa_file_paths]
+ qa_file_paths = [
+ os.path.join(args.output_dir, set_type + ".input" + str(i + 1))
+ for i in range(4)
+ ]
+ qa_files = [open(qa_file_path, "w") for qa_file_path in qa_file_paths]
outf_context_path = os.path.join(args.output_dir, set_type + ".input0")
outf_label_path = os.path.join(args.output_dir, set_type + ".label")
- outf_context = open(outf_context_path, 'w')
- outf_label = open(outf_label_path, 'w')
+ outf_context = open(outf_context_path, "w")
+ outf_label = open(outf_label_path, "w")
for example in examples:
- outf_context.write(example.paragraph + '\n')
+ outf_context.write(example.paragraph + "\n")
for i in range(4):
- qa_files[i].write(example.qa_list[i] + '\n')
- outf_label.write(str(example.label) + '\n')
+ qa_files[i].write(example.qa_list[i] + "\n")
+ outf_label.write(str(example.label) + "\n")
for f in qa_files:
f.close()
@@ -95,5 +98,5 @@ def main():
outf_context.close()
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/examples/roberta/wsc/README.md b/examples/roberta/wsc/README.md
index 0d3f62a07f..21a045d999 100644
--- a/examples/roberta/wsc/README.md
+++ b/examples/roberta/wsc/README.md
@@ -51,7 +51,7 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train WSC/ \
--no-epoch-checkpoints --no-last-checkpoints --no-save-optimizer-state \
--best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \
--valid-subset val \
- --fp16 --ddp-backend no_c10d \
+ --fp16 --ddp-backend legacy_ddp \
--user-dir $FAIRSEQ_USER_DIR \
--task wsc --criterion wsc --wsc-cross-entropy \
--arch roberta_large --bpe gpt2 --max-positions 512 \
@@ -59,7 +59,7 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train WSC/ \
--optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-06 \
--lr-scheduler polynomial_decay --lr $LR \
--warmup-updates $WARMUP_UPDATES --total-num-update $TOTAL_NUM_UPDATES \
- --max-sentences $MAX_SENTENCES \
+ --batch-size $MAX_SENTENCES \
--max-update $TOTAL_NUM_UPDATES \
--log-format simple --log-interval 100 \
--seed $SEED
@@ -110,7 +110,7 @@ CUDA_VISIBLE_DEVICES=0 fairseq-train winogrande_1.0/ \
--no-epoch-checkpoints --no-last-checkpoints --no-save-optimizer-state \
--best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \
--valid-subset val \
- --fp16 --ddp-backend no_c10d \
+ --fp16 --ddp-backend legacy_ddp \
--user-dir $FAIRSEQ_USER_DIR \
--task winogrande --criterion winogrande \
--wsc-margin-alpha 5.0 --wsc-margin-beta 0.4 \
@@ -119,7 +119,7 @@ CUDA_VISIBLE_DEVICES=0 fairseq-train winogrande_1.0/ \
--optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-06 \
--lr-scheduler polynomial_decay --lr $LR \
--warmup-updates $WARMUP_UPDATES --total-num-update $TOTAL_NUM_UPDATES \
- --max-sentences $MAX_SENTENCES \
+ --batch-size $MAX_SENTENCES \
--max-update $TOTAL_NUM_UPDATES \
--log-format simple --log-interval 100
```
diff --git a/examples/roberta/wsc/wsc_criterion.py b/examples/roberta/wsc/wsc_criterion.py
index dd909ab20c..ed0251fdec 100644
--- a/examples/roberta/wsc/wsc_criterion.py
+++ b/examples/roberta/wsc/wsc_criterion.py
@@ -7,23 +7,21 @@
import torch
import torch.nn.functional as F
-
from fairseq import utils
-from fairseq.data import encoders
from fairseq.criterions import LegacyFairseqCriterion, register_criterion
+from fairseq.data import encoders
-@register_criterion('wsc')
+@register_criterion("wsc")
class WSCCriterion(LegacyFairseqCriterion):
-
def __init__(self, args, task):
super().__init__(args, task)
if self.args.save_predictions is not None:
- self.prediction_h = open(self.args.save_predictions, 'w')
+ self.prediction_h = open(self.args.save_predictions, "w")
else:
self.prediction_h = None
- self.bpe = encoders.build_bpe(args)
- self.tokenizer = encoders.build_tokenizer(args)
+ self.bpe = encoders.build_bpe(args.bpe)
+ self.tokenizer = encoders.build_tokenizer(args.tokenizer)
def __del__(self):
if self.prediction_h is not None:
@@ -32,12 +30,16 @@ def __del__(self):
@staticmethod
def add_args(parser):
"""Add criterion-specific arguments to the parser."""
- parser.add_argument('--wsc-margin-alpha', type=float, metavar='A', default=1.0)
- parser.add_argument('--wsc-margin-beta', type=float, metavar='B', default=0.0)
- parser.add_argument('--wsc-cross-entropy', action='store_true',
- help='use cross entropy formulation instead of margin loss')
- parser.add_argument('--save-predictions', metavar='FILE',
- help='file to save predictions to')
+ parser.add_argument("--wsc-margin-alpha", type=float, metavar="A", default=1.0)
+ parser.add_argument("--wsc-margin-beta", type=float, metavar="B", default=0.0)
+ parser.add_argument(
+ "--wsc-cross-entropy",
+ action="store_true",
+ help="use cross entropy formulation instead of margin loss",
+ )
+ parser.add_argument(
+ "--save-predictions", metavar="FILE", help="file to save predictions to"
+ )
def get_masked_input(self, tokens, mask):
masked_tokens = tokens.clone()
@@ -60,27 +62,26 @@ def get_loss(self, query_lprobs, cand_lprobs):
)
else:
return (
- - query_lprobs
- + self.args.wsc_margin_alpha * (
- cand_lprobs - query_lprobs + self.args.wsc_margin_beta
- ).clamp(min=0)
+ -query_lprobs
+ + self.args.wsc_margin_alpha
+ * (cand_lprobs - query_lprobs + self.args.wsc_margin_beta).clamp(min=0)
).sum()
def forward(self, model, sample, reduce=True):
# compute loss and accuracy
- loss, nloss = 0., 0
+ loss, nloss = 0.0, 0
ncorrect, nqueries = 0, 0
- for i, label in enumerate(sample['labels']):
+ for i, label in enumerate(sample["labels"]):
query_lprobs = self.get_lprobs(
model,
- sample['query_tokens'][i].unsqueeze(0),
- sample['query_masks'][i].unsqueeze(0),
+ sample["query_tokens"][i].unsqueeze(0),
+ sample["query_masks"][i].unsqueeze(0),
)
cand_lprobs = self.get_lprobs(
model,
- sample['candidate_tokens'][i],
- sample['candidate_masks'][i],
+ sample["candidate_tokens"][i],
+ sample["candidate_masks"][i],
)
pred = (query_lprobs >= cand_lprobs).all().item()
@@ -95,72 +96,72 @@ def forward(self, model, sample, reduce=True):
nloss += 1
loss += self.get_loss(query_lprobs, cand_lprobs)
- id = sample['id'][i].item()
+ id = sample["id"][i].item()
if self.prediction_h is not None:
- print('{}\t{}\t{}'.format(id, pred, label), file=self.prediction_h)
+ print("{}\t{}\t{}".format(id, pred, label), file=self.prediction_h)
if nloss == 0:
loss = torch.tensor(0.0, requires_grad=True)
sample_size = nqueries if nqueries > 0 else 1
logging_output = {
- 'loss': utils.item(loss.data) if reduce else loss.data,
- 'ntokens': sample['ntokens'],
- 'nsentences': sample['nsentences'],
- 'sample_size': sample_size,
- 'ncorrect': ncorrect,
- 'nqueries': nqueries,
+ "loss": utils.item(loss.data) if reduce else loss.data,
+ "ntokens": sample["ntokens"],
+ "nsentences": sample["nsentences"],
+ "sample_size": sample_size,
+ "ncorrect": ncorrect,
+ "nqueries": nqueries,
}
return loss, sample_size, logging_output
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
- loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
- ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
- nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
- sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
+ loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
+ ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
+ nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
+ sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
agg_output = {
- 'loss': loss_sum / sample_size / math.log(2),
- 'ntokens': ntokens,
- 'nsentences': nsentences,
- 'sample_size': sample_size,
+ "loss": loss_sum / sample_size / math.log(2),
+ "ntokens": ntokens,
+ "nsentences": nsentences,
+ "sample_size": sample_size,
}
- ncorrect = sum(log.get('ncorrect', 0) for log in logging_outputs)
- nqueries = sum(log.get('nqueries', 0) for log in logging_outputs)
+ ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
+ nqueries = sum(log.get("nqueries", 0) for log in logging_outputs)
if nqueries > 0:
- agg_output['accuracy'] = ncorrect / float(nqueries)
+ agg_output["accuracy"] = ncorrect / float(nqueries)
return agg_output
-@register_criterion('winogrande')
+@register_criterion("winogrande")
class WinograndeCriterion(WSCCriterion):
def forward(self, model, sample, reduce=True):
# compute loss and accuracy
query_lprobs = self.get_lprobs(
model,
- sample['query_tokens'],
- sample['query_masks'],
+ sample["query_tokens"],
+ sample["query_masks"],
)
cand_lprobs = self.get_lprobs(
model,
- sample['candidate_tokens'],
- sample['candidate_masks'],
+ sample["candidate_tokens"],
+ sample["candidate_masks"],
)
pred = query_lprobs >= cand_lprobs
loss = self.get_loss(query_lprobs, cand_lprobs)
- sample_size = sample['query_tokens'].size(0)
+ sample_size = sample["query_tokens"].size(0)
ncorrect = pred.sum().item()
logging_output = {
- 'loss': utils.item(loss.data) if reduce else loss.data,
- 'ntokens': sample['ntokens'],
- 'nsentences': sample['nsentences'],
- 'sample_size': sample_size,
- 'ncorrect': ncorrect,
- 'nqueries': sample_size,
+ "loss": utils.item(loss.data) if reduce else loss.data,
+ "ntokens": sample["ntokens"],
+ "nsentences": sample["nsentences"],
+ "sample_size": sample_size,
+ "ncorrect": ncorrect,
+ "nqueries": sample_size,
}
return loss, sample_size, logging_output
diff --git a/examples/roberta/wsc/wsc_task.py b/examples/roberta/wsc/wsc_task.py
index fbba0d8964..602ea737ed 100644
--- a/examples/roberta/wsc/wsc_task.py
+++ b/examples/roberta/wsc/wsc_task.py
@@ -10,47 +10,51 @@
import numpy as np
import torch
import torch.nn.functional as F
-
from fairseq import utils
from fairseq.data import (
- data_utils,
Dictionary,
- encoders,
IdDataset,
ListDataset,
NestedDictionaryDataset,
- NumSamplesDataset,
NumelDataset,
+ NumSamplesDataset,
PadDataset,
SortDataset,
+ data_utils,
+ encoders,
)
-from fairseq.tasks import FairseqTask, register_task
+from fairseq.tasks import LegacyFairseqTask, register_task
from . import wsc_utils
-@register_task('wsc')
-class WSCTask(FairseqTask):
+@register_task("wsc")
+class WSCTask(LegacyFairseqTask):
"""Task to finetune RoBERTa for Winograd Schemas."""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
- parser.add_argument('data', metavar='DIR',
- help='path to data directory; we load
diff --git a/examples/speech_recognition/kaldi/kaldi_decoder.py b/examples/speech_recognition/kaldi/kaldi_decoder.py
new file mode 100644
index 0000000000..5f62cc58ae
--- /dev/null
+++ b/examples/speech_recognition/kaldi/kaldi_decoder.py
@@ -0,0 +1,244 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from concurrent.futures import ThreadPoolExecutor
+import logging
+from omegaconf import MISSING
+import os
+import torch
+from typing import Optional
+import warnings
+
+
+from dataclasses import dataclass
+from fairseq.dataclass import FairseqDataclass
+from .kaldi_initializer import KaldiInitializerConfig, initalize_kaldi
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class KaldiDecoderConfig(FairseqDataclass):
+ hlg_graph_path: Optional[str] = None
+ output_dict: str = MISSING
+
+ kaldi_initializer_config: Optional[KaldiInitializerConfig] = None
+
+ acoustic_scale: float = 0.5
+ max_active: int = 10000
+ beam_delta: float = 0.5
+ hash_ratio: float = 2.0
+
+ is_lattice: bool = False
+ lattice_beam: float = 10.0
+ prune_interval: int = 25
+ determinize_lattice: bool = True
+ prune_scale: float = 0.1
+ max_mem: int = 0
+ phone_determinize: bool = True
+ word_determinize: bool = True
+ minimize: bool = True
+
+ num_threads: int = 1
+
+
+class KaldiDecoder(object):
+ def __init__(
+ self,
+ cfg: KaldiDecoderConfig,
+ beam: int,
+ nbest: int = 1,
+ ):
+ try:
+ from kaldi.asr import FasterRecognizer, LatticeFasterRecognizer
+ from kaldi.base import set_verbose_level
+ from kaldi.decoder import (
+ FasterDecoder,
+ FasterDecoderOptions,
+ LatticeFasterDecoder,
+ LatticeFasterDecoderOptions,
+ )
+ from kaldi.lat.functions import DeterminizeLatticePhonePrunedOptions
+ from kaldi.fstext import read_fst_kaldi, SymbolTable
+ except:
+ warnings.warn(
+ "pykaldi is required for this functionality. Please install from https://github.com/pykaldi/pykaldi"
+ )
+
+ # set_verbose_level(2)
+
+ self.acoustic_scale = cfg.acoustic_scale
+ self.nbest = nbest
+
+ if cfg.hlg_graph_path is None:
+ assert (
+ cfg.kaldi_initializer_config is not None
+ ), "Must provide hlg graph path or kaldi initializer config"
+ cfg.hlg_graph_path = initalize_kaldi(cfg.kaldi_initializer_config)
+
+ assert os.path.exists(cfg.hlg_graph_path), cfg.hlg_graph_path
+
+ if cfg.is_lattice:
+ self.dec_cls = LatticeFasterDecoder
+ opt_cls = LatticeFasterDecoderOptions
+ self.rec_cls = LatticeFasterRecognizer
+ else:
+ assert self.nbest == 1, "nbest > 1 requires lattice decoder"
+ self.dec_cls = FasterDecoder
+ opt_cls = FasterDecoderOptions
+ self.rec_cls = FasterRecognizer
+
+ self.decoder_options = opt_cls()
+ self.decoder_options.beam = beam
+ self.decoder_options.max_active = cfg.max_active
+ self.decoder_options.beam_delta = cfg.beam_delta
+ self.decoder_options.hash_ratio = cfg.hash_ratio
+
+ if cfg.is_lattice:
+ self.decoder_options.lattice_beam = cfg.lattice_beam
+ self.decoder_options.prune_interval = cfg.prune_interval
+ self.decoder_options.determinize_lattice = cfg.determinize_lattice
+ self.decoder_options.prune_scale = cfg.prune_scale
+ det_opts = DeterminizeLatticePhonePrunedOptions()
+ det_opts.max_mem = cfg.max_mem
+ det_opts.phone_determinize = cfg.phone_determinize
+ det_opts.word_determinize = cfg.word_determinize
+ det_opts.minimize = cfg.minimize
+ self.decoder_options.det_opts = det_opts
+
+ self.output_symbols = {}
+ with open(cfg.output_dict, "r") as f:
+ for line in f:
+ items = line.rstrip().split()
+ assert len(items) == 2
+ self.output_symbols[int(items[1])] = items[0]
+
+ logger.info(f"Loading FST from {cfg.hlg_graph_path}")
+ self.fst = read_fst_kaldi(cfg.hlg_graph_path)
+ self.symbol_table = SymbolTable.read_text(cfg.output_dict)
+
+ self.executor = ThreadPoolExecutor(max_workers=cfg.num_threads)
+
+ def generate(self, models, sample, **unused):
+ """Generate a batch of inferences."""
+ # model.forward normally channels prev_output_tokens into the decoder
+ # separately, but SequenceGenerator directly calls model.encoder
+ encoder_input = {
+ k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
+ }
+ emissions, padding = self.get_emissions(models, encoder_input)
+ return self.decode(emissions, padding)
+
+ def get_emissions(self, models, encoder_input):
+ """Run encoder and normalize emissions"""
+ model = models[0]
+
+ all_encoder_out = [m(**encoder_input) for m in models]
+
+ if len(all_encoder_out) > 1:
+
+ if "encoder_out" in all_encoder_out[0]:
+ encoder_out = {
+ "encoder_out": sum(e["encoder_out"] for e in all_encoder_out)
+ / len(all_encoder_out),
+ "encoder_padding_mask": all_encoder_out[0]["encoder_padding_mask"],
+ }
+ padding = encoder_out["encoder_padding_mask"]
+ else:
+ encoder_out = {
+ "logits": sum(e["logits"] for e in all_encoder_out)
+ / len(all_encoder_out),
+ "padding_mask": all_encoder_out[0]["padding_mask"],
+ }
+ padding = encoder_out["padding_mask"]
+ else:
+ encoder_out = all_encoder_out[0]
+ padding = (
+ encoder_out["padding_mask"]
+ if "padding_mask" in encoder_out
+ else encoder_out["encoder_padding_mask"]
+ )
+
+ if hasattr(model, "get_logits"):
+ emissions = model.get_logits(encoder_out, normalize=True)
+ else:
+ emissions = model.get_normalized_probs(encoder_out, log_probs=True)
+
+ return (
+ emissions.cpu().float().transpose(0, 1),
+ padding.cpu() if padding is not None and padding.any() else None,
+ )
+
+ def decode_one(self, logits, padding):
+ from kaldi.matrix import Matrix
+
+ decoder = self.dec_cls(self.fst, self.decoder_options)
+ asr = self.rec_cls(
+ decoder, self.symbol_table, acoustic_scale=self.acoustic_scale
+ )
+
+ if padding is not None:
+ logits = logits[~padding]
+
+ mat = Matrix(logits.numpy())
+
+ out = asr.decode(mat)
+
+ if self.nbest > 1:
+ from kaldi.fstext import shortestpath
+ from kaldi.fstext.utils import (
+ convert_compact_lattice_to_lattice,
+ convert_lattice_to_std,
+ convert_nbest_to_list,
+ get_linear_symbol_sequence,
+ )
+
+ lat = out["lattice"]
+
+ sp = shortestpath(lat, nshortest=self.nbest)
+
+ sp = convert_compact_lattice_to_lattice(sp)
+ sp = convert_lattice_to_std(sp)
+ seq = convert_nbest_to_list(sp)
+
+ results = []
+ for s in seq:
+ _, o, w = get_linear_symbol_sequence(s)
+ words = list(self.output_symbols[z] for z in o)
+ results.append(
+ {
+ "tokens": words,
+ "words": words,
+ "score": w.value,
+ "emissions": logits,
+ }
+ )
+ return results
+ else:
+ words = out["text"].split()
+ return [
+ {
+ "tokens": words,
+ "words": words,
+ "score": out["likelihood"],
+ "emissions": logits,
+ }
+ ]
+
+ def decode(self, emissions, padding):
+ if padding is None:
+ padding = [None] * len(emissions)
+
+ ret = list(
+ map(
+ lambda e, p: self.executor.submit(self.decode_one, e, p),
+ emissions,
+ padding,
+ )
+ )
+ return ret
diff --git a/examples/speech_recognition/kaldi/kaldi_initializer.py b/examples/speech_recognition/kaldi/kaldi_initializer.py
new file mode 100644
index 0000000000..6d2a2a4b6b
--- /dev/null
+++ b/examples/speech_recognition/kaldi/kaldi_initializer.py
@@ -0,0 +1,698 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+import hydra
+from hydra.core.config_store import ConfigStore
+import logging
+from omegaconf import MISSING, OmegaConf
+import os
+import os.path as osp
+from pathlib import Path
+import subprocess
+from typing import Optional
+
+from fairseq.data.dictionary import Dictionary
+from fairseq.dataclass import FairseqDataclass
+
+script_dir = Path(__file__).resolve().parent
+config_path = script_dir / "config"
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class KaldiInitializerConfig(FairseqDataclass):
+ data_dir: str = MISSING
+ fst_dir: Optional[str] = None
+ in_labels: str = MISSING
+ out_labels: Optional[str] = None
+ wav2letter_lexicon: Optional[str] = None
+ lm_arpa: str = MISSING
+ kaldi_root: str = MISSING
+ blank_symbol: str = "
+ )
+ / meters["vocab_seen_pct"].avg ** self.cfg.vocab_usage_power,
+ )
+
+ metrics.log_derived(
+ "lm_ppl",
+ lambda meters: math.pow(
+ 10,
+ -meters["lm_score_sum"].sum
+ / (
+ meters["num_pred_chars"].sum + meters["nsentences"].sum
+ ), # account for
+ ),
+ )
+ else:
+ metrics.log_derived("weighted_lm_ppl", lambda meters: float("inf"))
+
+ if num_words > 0:
+ if word_lm_sum != 0:
+ metrics.log_derived(
+ "word_lm_ppl",
+ lambda meters: math.pow(
+ 10,
+ -meters["word_lm_sum"].sum
+ / (
+ meters["_num_words"].sum + meters["nsentences"].sum
+ ), # account for
+ ),
+ )
+ metrics.log_derived(
+ "weighted_word_lm_ppl",
+ lambda meters: math.pow(
+ 10,
+ -meters["word_lm_sum"].sum
+ / (
+ meters["_num_words"].sum + meters["nsentences"].sum
+ ), # account for
+ )
+ / meters["vocab_seen_pct"].avg ** self.cfg.vocab_usage_power,
+ )
+
+ if self.cfg.word_kenlm_path is not None:
+ metrics.log_derived(
+ "kaldi_score",
+ lambda meters: meters["kaldi_score_sum"].sum
+ / meters["nsentences"].sum,
+ )
+
+ def build_model(self, cfg: FairseqDataclass):
+ model = super().build_model(cfg)
+
+ return model
diff --git a/examples/wav2vec/unsupervised/w2vu_generate.py b/examples/wav2vec/unsupervised/w2vu_generate.py
new file mode 100644
index 0000000000..6177239dc7
--- /dev/null
+++ b/examples/wav2vec/unsupervised/w2vu_generate.py
@@ -0,0 +1,707 @@
+#!/usr/bin/env python3 -u
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Run inference for pre-processed data with a trained model.
+"""
+
+import ast
+from collections import namedtuple
+from dataclasses import dataclass, field
+from enum import Enum, auto
+import hydra
+from hydra.core.config_store import ConfigStore
+import logging
+import math
+import os
+from omegaconf import OmegaConf
+from typing import Optional
+import sys
+
+import editdistance
+import torch
+
+from hydra.core.hydra_config import HydraConfig
+
+from fairseq import checkpoint_utils, progress_bar, tasks, utils
+from fairseq.data.data_utils import post_process
+from fairseq.dataclass.configs import FairseqDataclass, FairseqConfig
+from fairseq.logging.meters import StopwatchMeter
+from omegaconf import open_dict
+
+from examples.speech_recognition.kaldi.kaldi_decoder import KaldiDecoderConfig
+
+logging.root.setLevel(logging.INFO)
+logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class DecoderType(Enum):
+ VITERBI = auto()
+ KENLM = auto()
+ FAIRSEQ = auto()
+ KALDI = auto()
+
+
+@dataclass
+class UnsupGenerateConfig(FairseqDataclass):
+ fairseq: FairseqConfig = FairseqConfig()
+ lm_weight: float = field(
+ default=2.0,
+ metadata={"help": "language model weight"},
+ )
+ w2l_decoder: DecoderType = field(
+ default=DecoderType.VITERBI,
+ metadata={"help": "type of decoder to use"},
+ )
+ kaldi_decoder_config: Optional[KaldiDecoderConfig] = None
+ lexicon: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "path to lexicon. This is also used to 'phonemize' for unsupvised param tuning"
+ },
+ )
+ lm_model: Optional[str] = field(
+ default=None,
+ metadata={"help": "path to language model (kenlm or fairseq)"},
+ )
+ unit_lm: bool = field(
+ default=False,
+ metadata={"help": "whether to use unit lm"},
+ )
+ beam_threshold: float = field(
+ default=50.0,
+ metadata={"help": "beam score threshold"},
+ )
+ beam_size_token: float = field(
+ default=100.0,
+ metadata={"help": "max tokens per beam"},
+ )
+ beam: int = field(
+ default=5,
+ metadata={"help": "decoder beam size"},
+ )
+ nbest: int = field(
+ default=1,
+ metadata={"help": "number of results to return"},
+ )
+ word_score: float = field(
+ default=1.0,
+ metadata={"help": "word score to add at end of word"},
+ )
+ unk_weight: float = field(
+ default=-math.inf,
+ metadata={"help": "unknown token weight"},
+ )
+ sil_weight: float = field(
+ default=0.0,
+ metadata={"help": "silence token weight"},
+ )
+ targets: Optional[str] = field(
+ default=None,
+ metadata={"help": "extension of ground truth labels to compute UER"},
+ )
+ results_path: Optional[str] = field(
+ default=None,
+ metadata={"help": "where to store results"},
+ )
+ post_process: Optional[str] = field(
+ default=None,
+ metadata={"help": "how to post process results"},
+ )
+ vocab_usage_power: float = field(
+ default=2,
+ metadata={"help": "for unsupervised param tuning"},
+ )
+
+ viterbi_transcript: Optional[str] = field(
+ default=None,
+ metadata={"help": "for unsupervised param tuning"},
+ )
+ min_lm_ppl: float = field(
+ default=0,
+ metadata={"help": "for unsupervised param tuning"},
+ )
+ min_vt_uer: float = field(
+ default=0,
+ metadata={"help": "for unsupervised param tuning"},
+ )
+
+ blank_weight: float = field(
+ default=0,
+ metadata={"help": "value to add or set for blank emission"},
+ )
+ blank_mode: str = field(
+ default="set",
+ metadata={
+ "help": "can be add or set, how to modify blank emission with blank weight"
+ },
+ )
+ sil_is_blank: bool = field(
+ default=False,
+ metadata={"help": "if true, "
+ silence_symbol: Optional[str] = None
+
+
+def create_units(fst_dir: Path, in_labels: str, vocab: Dictionary) -> Path:
+ in_units_file = fst_dir / f"kaldi_dict.{in_labels}.txt"
+ if not in_units_file.exists():
+
+ logger.info(f"Creating {in_units_file}")
+
+ with open(in_units_file, "w") as f:
+ print("
+ lm_score += cur_score
+ w_cnt += cur_cnt
+ logger.debug((
+ f"======================\n"
+ f"score sum/avg = {cur_score:.2f}/{cur_score/cur_cnt:.2f}\n"
+ f"hyp = {hyp}"
+ ))
+ lm_ppl = math.pow(10, -lm_score / w_cnt)
+ logger.debug(f"lm ppl = {lm_ppl:.2f}; num. of words = {w_cnt}")
+ return lm_ppl
+
+def main():
+ args = get_parser().parse_args()
+ logger.debug(f"Args: {args}")
+
+ ref_uid_to_tra = load_tra(args.ref_tra)
+ hyp_uid_to_tra = load_tra(args.hyp_tra)
+ assert not bool(set(hyp_uid_to_tra.keys()) - set(ref_uid_to_tra.keys()))
+
+ lm = kenlm.Model(args.kenlm_path)
+ skipwords = set(args.skipwords.split(","))
+ def compute_lm_score(s):
+ s = " ".join(w for w in s.split() if w not in skipwords)
+ s = s.upper() if args.uppercase else s
+ return lm.score(s)
+
+ g2p, g2p_dict = None, None
+ if args.phonemize:
+ if args.phonemize_lexicon:
+ g2p_dict = load_lex(args.phonemize_lexicon)
+ else:
+ g2p = G2p()
+
+ wer = compute_wer(ref_uid_to_tra, hyp_uid_to_tra, g2p, g2p_dict)
+ lm_ppl = compute_lm_ppl(hyp_uid_to_tra, compute_lm_score)
+
+ gt_wer = -math.inf
+ if args.gt_tra:
+ gt_uid_to_tra = load_tra(args.gt_tra)
+ gt_wer = compute_wer(gt_uid_to_tra, hyp_uid_to_tra, None, None)
+
+ score = math.log(lm_ppl) * max(wer, args.min_vt_uer)
+ logging.info(f"{args.hyp_tra}: score={score:.4f}; wer={wer*100:.2f}%; lm_ppl={lm_ppl:.4f}; gt_wer={gt_wer*100:.2f}%")
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/wav2vec/unsupervised/kaldi_self_train/st/local/unsup_select_decode.sh b/examples/wav2vec/unsupervised/kaldi_self_train/st/local/unsup_select_decode.sh
new file mode 100755
index 0000000000..b34c5b6e06
--- /dev/null
+++ b/examples/wav2vec/unsupervised/kaldi_self_train/st/local/unsup_select_decode.sh
@@ -0,0 +1,37 @@
+#!/bin/bash
+
+split="dev_other"
+ref_txt="" # ground truth transcript path
+psd_txt="" # pseudo transcript path
+get_best_wer=true
+dec_name="decode"
+graph_name="graph"
+kenlm_path=/checkpoint/abaevski/data/speech/libri/librispeech_lm_novox.phnc_o6.bin
+
+. ./cmd.sh
+. ./path.sh
+. parse_options.sh
+
+exp_root=$1
+unsup_args=""
+if [ $# -ge 2 ]; then
+ unsup_args=$2
+fi
+
+set -eu
+
+if [ ! -z $ref_txt ] && $get_best_wer; then
+ echo "==== WER w.r.t. real transcript (select based on unsupervised metric)"
+ for x in $exp_root/*/${dec_name}_${split}*; do
+ lang=$(dirname $x)/$graph_name
+
+ (
+ for tra in $x/scoring/*.tra; do
+ cat $tra | utils/int2sym.pl -f 2- $lang/words.txt | sed 's:", 0
+EOS_TOKEN, EOS_TOKEN_ID = "", 2
+PAD_TOKEN, PAD_TOKEN_ID = "'
+EOS_TOK = ''
+
+def text_to_sequence(text, cleaner_names):
+ '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
+
+ The text can optionally have ARPAbet sequences enclosed in curly braces embedded
+ in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
+
+ Args:
+ text: string to convert to a sequence
+ cleaner_names: names of the cleaner functions to run the text through
+
+ Returns:
+ List of integers corresponding to the symbols in the text
+ '''
+ sequence = []
+
+ # Check for curly braces and treat their contents as ARPAbet:
+ while len(text):
+ m = _curly_re.match(text)
+ if not m:
+ sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
+ break
+ sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
+ sequence += _arpabet_to_sequence(m.group(2))
+ text = m.group(3)
+
+ return sequence
+
+
+def sample_code_chunk(code, size):
+ assert(size > 0 and size <= len(code))
+ start = np.random.randint(len(code) - size + 1)
+ end = start + size
+ return code[start:end], start, end
+
+
+def code_to_sequence(code, code_dict, collapse_code):
+ if collapse_code:
+ prev_c = None
+ sequence = []
+ for c in code:
+ if c in code_dict and c != prev_c:
+ sequence.append(code_dict[c])
+ prev_c = c
+ else:
+ sequence = [code_dict[c] for c in code if c in code_dict]
+ if len(sequence) < 0.95 * len(code):
+ print('WARNING : over 5%% codes are OOV')
+
+ return sequence
+
+
+def sequence_to_text(sequence):
+ '''Converts a sequence of IDs back to a string'''
+ result = ''
+ for symbol_id in sequence:
+ if symbol_id in _id_to_symbol:
+ s = _id_to_symbol[symbol_id]
+ # Enclose ARPAbet back in curly braces:
+ if len(s) > 1 and s[0] == '@':
+ s = '{%s}' % s[1:]
+ result += s
+ return result.replace('}{', ' ')
+
+
+def sequence_to_code(sequence, code_dict):
+ '''Analogous to sequence_to_text'''
+ id_to_code = {i: c for c, i in code_dict.items()}
+ return ' '.join([id_to_code[i] for i in sequence])
+
+
+def _clean_text(text, cleaner_names):
+ for name in cleaner_names:
+ cleaner = getattr(cleaners, name)
+ if not cleaner:
+ raise Exception('Unknown cleaner: %s' % name)
+ text = cleaner(text)
+ return text
+
+
+def _symbols_to_sequence(symbols):
+ return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]
+
+
+def _arpabet_to_sequence(text):
+ return _symbols_to_sequence(['@' + s for s in text.split()])
+
+
+def _should_keep_symbol(s):
+ return s in _symbol_to_id and s != '_' and s != '~'
diff --git a/examples/textless_nlp/gslm/unit2speech/tacotron2/utils.py b/examples/textless_nlp/gslm/unit2speech/tacotron2/utils.py
new file mode 100644
index 0000000000..66a426d222
--- /dev/null
+++ b/examples/textless_nlp/gslm/unit2speech/tacotron2/utils.py
@@ -0,0 +1,167 @@
+import collections
+import io
+import json
+import librosa
+import numpy as np
+import soundfile as sf
+import time
+import torch
+from scipy.io.wavfile import read
+from .text import SOS_TOK, EOS_TOK
+
+
+def get_mask_from_lengths(lengths):
+ max_len = torch.max(lengths).item()
+ ids = torch.arange(0, max_len, out=torch.cuda.LongTensor(max_len))
+ mask = (ids < lengths.unsqueeze(1))
+ return mask
+
+
+def load_wav_to_torch(full_path, sr=None):
+ data, sr = librosa.load(full_path, sr=sr)
+ data = np.clip(data, -1, 1) # potentially out of [-1, 1] due to resampling
+ data = data * 32768.0 # match values loaded by scipy
+ return torch.FloatTensor(data.astype(np.float32)), sr
+
+
+def read_binary_audio(bin_data, tar_sr=None):
+ """
+ read binary audio (`bytes` or `uint8` `numpy.ndarray`) to `float32`
+ `numpy.ndarray`
+
+ RETURNS:
+ data (np.ndarray) : audio of shape (n,) or (2, n)
+ tar_sr (int) : sample rate
+ """
+ data, ori_sr = sf.read(io.BytesIO(bin_data), dtype='float32')
+ data = data.T
+ if (tar_sr is not None) and (ori_sr != tar_sr):
+ data = librosa.resample(data, ori_sr, tar_sr)
+ else:
+ tar_sr = ori_sr
+ data = np.clip(data, -1, 1)
+ data = data * 32768.0
+ return torch.FloatTensor(data.astype(np.float32)), tar_sr
+
+
+def load_filepaths_and_text(filename):
+ with open(filename, encoding='utf-8') as f:
+ data = [json.loads(line.rstrip()) for line in f]
+ return data
+
+
+def to_gpu(x):
+ x = x.contiguous()
+
+ if torch.cuda.is_available():
+ x = x.cuda(non_blocking=True)
+ return torch.autograd.Variable(x)
+
+
+def load_code_dict(path, add_sos=False, add_eos=False):
+ if not path:
+ return {}
+
+ with open(path, 'r') as f:
+ codes = ['_'] + [line.rstrip() for line in f] # '_' for pad
+ code_dict = {c: i for i, c in enumerate(codes)}
+
+ if add_sos:
+ code_dict[SOS_TOK] = len(code_dict)
+ if add_eos:
+ code_dict[EOS_TOK] = len(code_dict)
+ assert(set(code_dict.values()) == set(range(len(code_dict))))
+
+ return code_dict
+
+
+def load_obs_label_dict(path):
+ if not path:
+ return {}
+ with open(path, 'r') as f:
+ obs_labels = [line.rstrip() for line in f]
+ return {c: i for i, c in enumerate(obs_labels)}
+
+
+# A simple timer class inspired from `tnt.TimeMeter`
+class CudaTimer:
+ def __init__(self, keys):
+ self.keys = keys
+ self.reset()
+
+ def start(self, key):
+ s = torch.cuda.Event(enable_timing=True)
+ s.record()
+ self.start_events[key].append(s)
+ return self
+
+ def stop(self, key):
+ e = torch.cuda.Event(enable_timing=True)
+ e.record()
+ self.end_events[key].append(e)
+ return self
+
+ def reset(self):
+ self.start_events = collections.defaultdict(list)
+ self.end_events = collections.defaultdict(list)
+ self.running_times = collections.defaultdict(float)
+ self.n = collections.defaultdict(int)
+ return self
+
+ def value(self):
+ self._synchronize()
+ return {k: self.running_times[k] / self.n[k] for k in self.keys}
+
+ def _synchronize(self):
+ torch.cuda.synchronize()
+ for k in self.keys:
+ starts = self.start_events[k]
+ ends = self.end_events[k]
+ if len(starts) == 0:
+ raise ValueError("Trying to divide by zero in TimeMeter")
+ if len(ends) != len(starts):
+ raise ValueError("Call stop before checking value!")
+ time = 0
+ for start, end in zip(starts, ends):
+ time += start.elapsed_time(end)
+ self.running_times[k] += time * 1e-3
+ self.n[k] += len(starts)
+ self.start_events = collections.defaultdict(list)
+ self.end_events = collections.defaultdict(list)
+
+
+# Used to measure the time taken for multiple events
+class Timer:
+ def __init__(self, keys):
+ self.keys = keys
+ self.n = {}
+ self.running_time = {}
+ self.total_time = {}
+ self.reset()
+
+ def start(self, key):
+ self.running_time[key] = time.time()
+ return self
+
+ def stop(self, key):
+ self.total_time[key] = time.time() - self.running_time[key]
+ self.n[key] += 1
+ self.running_time[key] = None
+ return self
+
+ def reset(self):
+ for k in self.keys:
+ self.total_time[k] = 0
+ self.running_time[k] = None
+ self.n[k] = 0
+ return self
+
+ def value(self):
+ vals = {}
+ for k in self.keys:
+ if self.n[k] == 0:
+ raise ValueError("Trying to divide by zero in TimeMeter")
+ else:
+ vals[k] = self.total_time[k] / self.n[k]
+ return vals
+
diff --git a/examples/textless_nlp/gslm/unit2speech/tacotron2/waveglow_denoiser.py b/examples/textless_nlp/gslm/unit2speech/tacotron2/waveglow_denoiser.py
new file mode 100644
index 0000000000..6a6585e8b6
--- /dev/null
+++ b/examples/textless_nlp/gslm/unit2speech/tacotron2/waveglow_denoiser.py
@@ -0,0 +1,40 @@
+# import sys
+# sys.path.append('tacotron2')
+import torch
+from .layers import STFT
+
+
+class Denoiser(torch.nn.Module):
+ """ Removes model bias from audio produced with waveglow """
+
+ def __init__(self, waveglow, filter_length=1024, n_overlap=4,
+ win_length=1024, mode='zeros'):
+ super(Denoiser, self).__init__()
+ self.stft = STFT(filter_length=filter_length,
+ hop_length=int(filter_length/n_overlap),
+ win_length=win_length).cuda()
+ if mode == 'zeros':
+ mel_input = torch.zeros(
+ (1, 80, 88),
+ dtype=waveglow.upsample.weight.dtype,
+ device=waveglow.upsample.weight.device)
+ elif mode == 'normal':
+ mel_input = torch.randn(
+ (1, 80, 88),
+ dtype=waveglow.upsample.weight.dtype,
+ device=waveglow.upsample.weight.device)
+ else:
+ raise Exception("Mode {} if not supported".format(mode))
+
+ with torch.no_grad():
+ bias_audio = waveglow.infer(mel_input, sigma=0.0).float()
+ bias_spec, _ = self.stft.transform(bias_audio)
+
+ self.register_buffer('bias_spec', bias_spec[:, :, 0][:, :, None])
+
+ def forward(self, audio, strength=0.1):
+ audio_spec, audio_angles = self.stft.transform(audio.cuda().float())
+ audio_spec_denoised = audio_spec - self.bias_spec * strength
+ audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0)
+ audio_denoised = self.stft.inverse(audio_spec_denoised, audio_angles)
+ return audio_denoised
diff --git a/examples/textless_nlp/gslm/unit2speech/tts_data.py b/examples/textless_nlp/gslm/unit2speech/tts_data.py
new file mode 100644
index 0000000000..eb0f7c360d
--- /dev/null
+++ b/examples/textless_nlp/gslm/unit2speech/tts_data.py
@@ -0,0 +1,52 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+import numpy as np
+from examples.textless_nlp.gslm.unit2speech.tacotron2.text import (
+ EOS_TOK,
+ SOS_TOK,
+ code_to_sequence,
+ text_to_sequence,
+)
+from examples.textless_nlp.gslm.unit2speech.tacotron2.utils import (
+ load_code_dict,
+)
+
+
+class TacotronInputDataset:
+ def __init__(self, hparams, append_str=""):
+ self.is_text = getattr(hparams, "text_or_code", "text") == "text"
+ if not self.is_text:
+ self.code_dict = load_code_dict(hparams.code_dict)
+ self.code_key = hparams.code_key
+ self.add_sos = hparams.add_sos
+ self.add_eos = hparams.add_eos
+ self.collapse_code = hparams.collapse_code
+ self.append_str = append_str
+
+ def process_code(self, inp_str):
+ inp_toks = inp_str.split()
+ if self.add_sos:
+ inp_toks = [SOS_TOK] + inp_toks
+ if self.add_eos:
+ inp_toks = inp_toks + [EOS_TOK]
+ return code_to_sequence(inp_toks, self.code_dict, self.collapse_code)
+
+ def process_text(self, inp_str):
+ return text_to_sequence(inp_str, ["english_cleaners"])
+
+ def get_tensor(self, inp_str):
+ # uid, txt, inp_str = self._get_data(idx)
+ inp_str = inp_str + self.append_str
+ if self.is_text:
+ inp_toks = self.process_text(inp_str)
+ else:
+ inp_toks = self.process_code(inp_str)
+ return torch.from_numpy(np.array(inp_toks)).long()
+
+ def __len__(self):
+ return len(self.data)
diff --git a/examples/textless_nlp/gslm/unit2speech/utils.py b/examples/textless_nlp/gslm/unit2speech/utils.py
new file mode 100644
index 0000000000..7aced08d38
--- /dev/null
+++ b/examples/textless_nlp/gslm/unit2speech/utils.py
@@ -0,0 +1,55 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+from examples.textless_nlp.gslm.unit2speech.tacotron2.model import Tacotron2
+from examples.textless_nlp.gslm.unit2speech.tacotron2.waveglow_denoiser import (
+ Denoiser,
+)
+
+
+def load_quantized_audio_from_file(file_path):
+ base_fname_batch, quantized_units_batch = [], []
+ with open(file_path) as f:
+ for line in f:
+ base_fname, quantized_units_str = line.rstrip().split("|")
+ quantized_units = [int(q) for q in quantized_units_str.split(" ")]
+ base_fname_batch.append(base_fname)
+ quantized_units_batch.append(quantized_units)
+ return base_fname_batch, quantized_units_batch
+
+
+def synthesize_audio(model, waveglow, denoiser, inp, lab=None, strength=0.0):
+ assert inp.size(0) == 1
+ inp = inp.cuda()
+ if lab is not None:
+ lab = torch.LongTensor(1).cuda().fill_(lab)
+
+ with torch.no_grad():
+ _, mel, _, ali, has_eos = model.inference(inp, lab, ret_has_eos=True)
+ aud = waveglow.infer(mel, sigma=0.666)
+ aud_dn = denoiser(aud, strength=strength).squeeze(1)
+ return mel, aud, aud_dn, has_eos
+
+
+def load_tacotron(tacotron_model_path, max_decoder_steps):
+ ckpt_dict = torch.load(tacotron_model_path)
+ hparams = ckpt_dict["hparams"]
+ hparams.max_decoder_steps = max_decoder_steps
+ sr = hparams.sampling_rate
+ model = Tacotron2(hparams)
+ model.load_state_dict(ckpt_dict["model_dict"])
+ model = model.cuda().eval().half()
+ return model, sr, hparams
+
+
+def load_waveglow(waveglow_path):
+ waveglow = torch.load(waveglow_path)["model"]
+ waveglow = waveglow.cuda().eval().half()
+ for k in waveglow.convinv:
+ k.float()
+ denoiser = Denoiser(waveglow)
+ return waveglow, denoiser
diff --git a/examples/translation/README.md b/examples/translation/README.md
index 67e99f6efd..2941f5eb84 100644
--- a/examples/translation/README.md
+++ b/examples/translation/README.md
@@ -175,9 +175,11 @@ mkdir -p checkpoints/fconv_wmt_en_de
fairseq-train \
data-bin/wmt17_en_de \
--arch fconv_wmt_en_de \
- --lr 0.5 --clip-norm 0.1 --dropout 0.2 --max-tokens 4000 \
+ --dropout 0.2 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
- --lr-scheduler fixed --force-anneal 50 \
+ --optimizer nag --clip-norm 0.1 \
+ --lr 0.5 --lr-scheduler fixed --force-anneal 50 \
+ --max-tokens 4000 \
--save-dir checkpoints/fconv_wmt_en_de
# Evaluate
@@ -205,10 +207,12 @@ fairseq-preprocess \
mkdir -p checkpoints/fconv_wmt_en_fr
fairseq-train \
data-bin/wmt14_en_fr \
- --lr 0.5 --clip-norm 0.1 --dropout 0.1 --max-tokens 3000 \
- --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
- --lr-scheduler fixed --force-anneal 50 \
--arch fconv_wmt_en_fr \
+ --dropout 0.1 \
+ --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
+ --optimizer nag --clip-norm 0.1 \
+ --lr 0.5 --lr-scheduler fixed --force-anneal 50 \
+ --max-tokens 3000 \
--save-dir checkpoints/fconv_wmt_en_fr
# Evaluate
@@ -225,7 +229,7 @@ train a multilingual `{de,fr}-en` translation model using the IWSLT'17 datasets.
Note that we use slightly different preprocessing here than for the IWSLT'14
En-De data above. In particular we learn a joint BPE code for all three
-languages and use interactive.py and sacrebleu for scoring the test set.
+languages and use fairseq-interactive and sacrebleu for scoring the test set.
```bash
# First install sacrebleu and sentencepiece
@@ -259,12 +263,12 @@ fairseq-preprocess --source-lang fr --target-lang en \
mkdir -p checkpoints/multilingual_transformer
CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt17.de_fr.en.bpe16k/ \
--max-epoch 50 \
- --ddp-backend=no_c10d \
+ --ddp-backend=legacy_ddp \
--task multilingual_translation --lang-pairs de-en,fr-en \
--arch multilingual_transformer_iwslt_de_en \
--share-decoders --share-decoder-input-output-embed \
--optimizer adam --adam-betas '(0.9, 0.98)' \
- --lr 0.0005 --lr-scheduler inverse_sqrt --min-lr '1e-09' \
+ --lr 0.0005 --lr-scheduler inverse_sqrt \
--warmup-updates 4000 --warmup-init-lr '1e-07' \
--label-smoothing 0.1 --criterion label_smoothed_cross_entropy \
--dropout 0.3 --weight-decay 0.0001 \
diff --git a/examples/translation/prepare-iwslt14.sh b/examples/translation/prepare-iwslt14.sh
index 0bf0dc2a2e..2fb6643fbc 100644
--- a/examples/translation/prepare-iwslt14.sh
+++ b/examples/translation/prepare-iwslt14.sh
@@ -15,7 +15,7 @@ CLEAN=$SCRIPTS/training/clean-corpus-n.perl
BPEROOT=subword-nmt/subword_nmt
BPE_TOKENS=10000
-URL="https://wit3.fbk.eu/archive/2014-01/texts/de/en/de-en.tgz"
+URL="http://dl.fbaipublicfiles.com/fairseq/data/iwslt14/de-en.tgz"
GZ=de-en.tgz
if [ ! -d "$SCRIPTS" ]; then
diff --git a/examples/translation_moe/README.md b/examples/translation_moe/README.md
index 33f1bee5cb..2e5c8af617 100644
--- a/examples/translation_moe/README.md
+++ b/examples/translation_moe/README.md
@@ -15,16 +15,16 @@ The model is trained with online responsibility assignment and shared parameteri
The following command will train a `hMoElp` model with `3` experts:
```bash
-fairseq-train --ddp-backend='no_c10d' \
+fairseq-train --ddp-backend='legacy_ddp' \
data-bin/wmt17_en_de \
--max-update 100000 \
- --task translation_moe --user-dir examples/translation_moe/src \
+ --task translation_moe --user-dir examples/translation_moe/translation_moe_src \
--method hMoElp --mean-pool-gating-network \
--num-experts 3 \
--arch transformer_wmt_en_de --share-all-embeddings \
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \
- --lr 0.0007 --min-lr 1e-09 \
+ --lr 0.0007 \
--dropout 0.1 --weight-decay 0.0 --criterion cross_entropy \
--max-tokens 3584
```
@@ -37,7 +37,7 @@ For example, to generate from expert 0:
fairseq-generate data-bin/wmt17_en_de \
--path checkpoints/checkpoint_best.pt \
--beam 1 --remove-bpe \
- --task translation_moe --user-dir examples/translation_moe/src \
+ --task translation_moe --user-dir examples/translation_moe/translation_moe_src \
--method hMoElp --mean-pool-gating-network \
--num-experts 3 \
--gen-expert 0
@@ -61,7 +61,7 @@ for EXPERT in $(seq 0 2); do \
--beam 1 \
--bpe subword_nmt --bpe-codes $BPE_CODE \
--buffer-size 500 --max-tokens 6000 \
- --task translation_moe --user-dir examples/translation_moe/src \
+ --task translation_moe --user-dir examples/translation_moe/translation_moe_src \
--method hMoElp --mean-pool-gating-network \
--num-experts 3 \
--gen-expert $EXPERT ; \
diff --git a/examples/translation_moe/score.py b/examples/translation_moe/score.py
index 8e207093db..9a529a9850 100644
--- a/examples/translation_moe/score.py
+++ b/examples/translation_moe/score.py
@@ -12,9 +12,9 @@
"""
import argparse
-from itertools import chain
-import sys
import random
+import sys
+from itertools import chain
import numpy as np
from sacrebleu import compute_bleu, corpus_bleu as _corpus_bleu
@@ -22,17 +22,21 @@
def main():
parser = argparse.ArgumentParser(sys.argv[0])
- parser.add_argument('--sys', nargs='*', default='', metavar='FILE',
- help='path to system output')
- parser.add_argument('--ref', default='', metavar='FILE',
- help='path to references')
- parser.add_argument('--output', default='', metavar='FILE',
- help='print outputs into a pretty format')
+ parser.add_argument(
+ "--sys", nargs="*", default="", metavar="FILE", help="path to system output"
+ )
+ parser.add_argument("--ref", default="", metavar="FILE", help="path to references")
+ parser.add_argument(
+ "--output",
+ default="",
+ metavar="FILE",
+ help="print outputs into a pretty format",
+ )
args = parser.parse_args()
if args.sys:
src, tgt, hypos, log_probs = load_sys(args.sys)
- print('pairwise BLEU: %.2f' % pairwise(hypos))
+ print("pairwise BLEU: %.2f" % pairwise(hypos))
if args.output:
merge(src, tgt, hypos, log_probs, args.output)
@@ -55,18 +59,21 @@ def load_sys(paths):
with open(path) as f:
for line in f:
line = line.rstrip()
- if line.startswith(('S-', 'T-', 'H-')):
- i = int(line[line.find('-')+1:line.find('\t')])
- if line.startswith('S-'):
- src[i] = line.split('\t')[1]
- if line.startswith('T-'):
- tgt[i] = line.split('\t')[1]
- if line.startswith('H-'):
+ # S: source
+ # T: target
+ # D: detokenized system output
+ if line.startswith(("S-", "T-", "D-")):
+ i = int(line[line.find("-") + 1 : line.find("\t")])
+ if line.startswith("S-"):
+ src[i] = line.split("\t")[1]
+ if line.startswith("T-"):
+ tgt[i] = line.split("\t")[1]
+ if line.startswith("D-"):
if i not in hypos:
hypos[i] = []
log_probs[i] = []
- hypos[i].append(line.split('\t')[2])
- log_probs[i].append(float(line.split('\t')[1]))
+ hypos[i].append(line.split("\t")[2])
+ log_probs[i].append(float(line.split("\t")[1]))
return dictolist(src), dictolist(tgt), dictolist(hypos), dictolist(log_probs)
@@ -76,34 +83,34 @@ def load_ref(path):
src, tgt, refs = [], [], []
i = 0
while i < len(lines):
- if lines[i].startswith('S-'):
- src.append(lines[i].split('\t')[1].rstrip())
+ if lines[i].startswith("S-"):
+ src.append(lines[i].split("\t")[1].rstrip())
i += 1
- elif lines[i].startswith('T-'):
- tgt.append(lines[i].split('\t')[1].rstrip())
+ elif lines[i].startswith("T-"):
+ tgt.append(lines[i].split("\t")[1].rstrip())
i += 1
else:
a = []
- while i < len(lines) and lines[i].startswith('R'):
- a.append(lines[i].split('\t')[1].rstrip())
+ while i < len(lines) and lines[i].startswith("R"):
+ a.append(lines[i].split("\t")[1].rstrip())
i += 1
refs.append(a)
return src, tgt, refs
def merge(src, tgt, hypos, log_probs, path):
- with open(path, 'w') as f:
+ with open(path, "w") as f:
for s, t, hs, lps in zip(src, tgt, hypos, log_probs):
- f.write(s + '\n')
- f.write(t + '\n')
- f.write('\n')
+ f.write(s + "\n")
+ f.write(t + "\n")
+ f.write("\n")
for h, lp in zip(hs, lps):
- f.write('\t%f\t%s\n' % (lp, h.strip()))
- f.write('------------------------------------------------------\n')
+ f.write("\t%f\t%s\n" % (lp, h.strip()))
+ f.write("------------------------------------------------------\n")
def corpus_bleu(sys_stream, ref_streams):
- bleu = _corpus_bleu(sys_stream, ref_streams, tokenize='none')
+ bleu = _corpus_bleu(sys_stream, ref_streams, tokenize="none")
return bleu.score
@@ -113,9 +120,11 @@ def sentence_bleu(hypothesis, reference):
bleu.counts[i] += 1
bleu.totals[i] += 1
bleu = compute_bleu(
- bleu.counts, bleu.totals,
- bleu.sys_len, bleu.ref_len,
- smooth='exp', smooth_floor=0.0,
+ bleu.counts,
+ bleu.totals,
+ bleu.sys_len,
+ bleu.ref_len,
+ smooth_method="exp",
)
return bleu.score
@@ -147,7 +156,7 @@ def multi_ref(refs, hypos):
best = [k for k in range(len(rs)) if s[k] == s[j]]
a.add(random.choice(best))
ref_cnt += len(a)
- print('#refs covered: %.2f' % (ref_cnt / len(refs)))
+ print("#refs covered: %.2f" % (ref_cnt / len(refs)))
# transpose refs and hypos
refs = list(zip(*refs))
@@ -157,33 +166,32 @@ def multi_ref(refs, hypos):
k = len(hypos)
m = len(refs)
flat_hypos = [hypos[j][i] for i in range(len(hypos[0])) for j in range(k)]
- duplicated_refs = [
- [ref for ref in refs_i for _ in range(k)]
- for refs_i in refs
- ]
+ duplicated_refs = [[ref for ref in refs_i for _ in range(k)] for refs_i in refs]
loo_bleus = []
for held_out_ref in range(m):
- remaining_refs = duplicated_refs[:held_out_ref] + duplicated_refs[held_out_ref+1:]
+ remaining_refs = (
+ duplicated_refs[:held_out_ref] + duplicated_refs[held_out_ref + 1 :]
+ )
assert len(remaining_refs) == m - 1
loo_bleus.append(corpus_bleu(flat_hypos, remaining_refs))
- print('average multi-reference BLEU (leave-one-out): %.2f' % np.mean(loo_bleus))
+ print("average multi-reference BLEU (leave-one-out): %.2f" % np.mean(loo_bleus))
def intra_ref(refs):
- print('ref pairwise BLEU: %.2f' % pairwise(refs))
+ print("ref pairwise BLEU: %.2f" % pairwise(refs))
refs = list(zip(*refs))
m = len(refs)
concat_h = []
concat_rest = [[] for j in range(m - 1)]
for i, h in enumerate(refs):
- rest = refs[:i] + refs[i+1:]
+ rest = refs[:i] + refs[i + 1 :]
concat_h.append(h)
for j in range(m - 1):
concat_rest[j].extend(rest[j])
concat_h = list(chain.from_iterable(concat_h))
bleu = corpus_bleu(concat_h, concat_rest)
- print('multi-reference BLEU (leave-one-out): %.2f' % bleu)
+ print("multi-reference BLEU (leave-one-out): %.2f" % bleu)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/examples/translation_moe/src/__init__.py b/examples/translation_moe/translation_moe_src/__init__.py
similarity index 100%
rename from examples/translation_moe/src/__init__.py
rename to examples/translation_moe/translation_moe_src/__init__.py
diff --git a/examples/translation_moe/src/logsumexp_moe.py b/examples/translation_moe/translation_moe_src/logsumexp_moe.py
similarity index 95%
rename from examples/translation_moe/src/logsumexp_moe.py
rename to examples/translation_moe/translation_moe_src/logsumexp_moe.py
index 0379f226b0..fb299daecb 100644
--- a/examples/translation_moe/src/logsumexp_moe.py
+++ b/examples/translation_moe/translation_moe_src/logsumexp_moe.py
@@ -21,6 +21,6 @@ def forward(ctx, logp, posterior, dim=-1):
@staticmethod
def backward(ctx, grad_output):
- posterior, = ctx.saved_tensors
+ (posterior,) = ctx.saved_tensors
grad_logp = grad_output.unsqueeze(ctx.dim) * posterior
return grad_logp, None, None
diff --git a/examples/translation_moe/src/mean_pool_gating_network.py b/examples/translation_moe/translation_moe_src/mean_pool_gating_network.py
similarity index 80%
rename from examples/translation_moe/src/mean_pool_gating_network.py
rename to examples/translation_moe/translation_moe_src/mean_pool_gating_network.py
index 25743b4e98..efc7ae40bf 100644
--- a/examples/translation_moe/src/mean_pool_gating_network.py
+++ b/examples/translation_moe/translation_moe_src/mean_pool_gating_network.py
@@ -26,15 +26,15 @@ def __init__(self, embed_dim, num_experts, dropout=None):
def forward(self, encoder_out):
if not (
- hasattr(encoder_out, 'encoder_out')
- and hasattr(encoder_out, 'encoder_padding_mask')
- and encoder_out.encoder_out.size(2) == self.embed_dim
+ "encoder_out" in encoder_out
+ and "encoder_padding_mask" in encoder_out
+ and encoder_out["encoder_out"][0].size(2) == self.embed_dim
):
- raise ValueError('Unexpected format for encoder_out')
+ raise ValueError("Unexpected format for encoder_out")
# mean pooling over time
- encoder_padding_mask = encoder_out.encoder_padding_mask # B x T
- encoder_out = encoder_out.encoder_out.transpose(0, 1) # B x T x C
+ encoder_padding_mask = encoder_out["encoder_padding_mask"][0] # B x T
+ encoder_out = encoder_out["encoder_out"][0].transpose(0, 1) # B x T x C
if encoder_padding_mask is not None:
encoder_out = encoder_out.clone() # required because of transpose above
encoder_out[encoder_padding_mask] = 0
diff --git a/examples/translation_moe/src/translation_moe.py b/examples/translation_moe/translation_moe_src/translation_moe.py
similarity index 53%
rename from examples/translation_moe/src/translation_moe.py
rename to examples/translation_moe/translation_moe_src/translation_moe.py
index 61e4bed809..7f28c32dd6 100644
--- a/examples/translation_moe/src/translation_moe.py
+++ b/examples/translation_moe/translation_moe_src/translation_moe.py
@@ -3,17 +3,52 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
+from dataclasses import dataclass, field
import torch
+from omegaconf import II
from fairseq import metrics, utils
+from fairseq.dataclass import ChoiceEnum
from fairseq.tasks import register_task
-from fairseq.tasks.translation import TranslationTask
+from fairseq.tasks.translation import TranslationConfig, TranslationTask
from .logsumexp_moe import LogSumExpMoE
from .mean_pool_gating_network import MeanPoolGatingNetwork
-@register_task('translation_moe')
+METHOD_CHOICES = ChoiceEnum(["sMoElp", "sMoEup", "hMoElp", "hMoEup"])
+
+
+@dataclass
+class TranslationMoEConfig(TranslationConfig):
+ method: METHOD_CHOICES = field(
+ default="hMoEup",
+ metadata={"help": "MoE method"},
+ )
+ num_experts: int = field(
+ default=3,
+ metadata={"help": "number of experts"},
+ )
+ mean_pool_gating_network: bool = field(
+ default=False,
+ metadata={"help": "use a simple mean-pooling gating network"},
+ )
+ mean_pool_gating_network_dropout: float = field(
+ default=0,
+ metadata={"help": "dropout for mean-pooling gating network"},
+ )
+ mean_pool_gating_network_encoder_dim: int = field(
+ default=0,
+ metadata={"help": "encoder output dim for mean-pooling gating network"},
+ )
+ gen_expert: int = field(
+ default=0,
+ metadata={"help": "which expert to use for generation"},
+ )
+ sentence_avg: bool = II("optimization.sentence_avg")
+
+
+@register_task("translation_moe", dataclass=TranslationMoEConfig)
class TranslationMoETask(TranslationTask):
"""
Translation task for Mixture of Experts (MoE) models.
@@ -38,90 +73,79 @@ class TranslationMoETask(TranslationTask):
:prog:
"""
- @staticmethod
- def add_args(parser):
- """Add task-specific arguments to the parser."""
- # fmt: off
- TranslationTask.add_args(parser)
- parser.add_argument('--method', default='hMoEup',
- choices=['sMoElp', 'sMoEup', 'hMoElp', 'hMoEup'])
- parser.add_argument('--num-experts', default=3, type=int, metavar='N',
- help='number of experts')
- parser.add_argument('--mean-pool-gating-network', action='store_true',
- help='use a simple mean-pooling gating network')
- parser.add_argument('--mean-pool-gating-network-dropout', type=float,
- help='dropout for mean-pooling gating network')
- parser.add_argument('--mean-pool-gating-network-encoder-dim', type=float,
- help='encoder output dim for mean-pooling gating network')
- parser.add_argument('--gen-expert', type=int, default=0,
- help='which expert to use for generation')
- # fmt: on
-
- def __init__(self, args, src_dict, tgt_dict):
- if args.method == 'sMoElp':
+ cfg: TranslationMoEConfig
+
+ def __init__(self, cfg: TranslationMoEConfig, src_dict, tgt_dict):
+ if cfg.method == "sMoElp":
# soft MoE with learned prior
self.uniform_prior = False
self.hard_selection = False
- elif args.method == 'sMoEup':
+ elif cfg.method == "sMoEup":
# soft MoE with uniform prior
self.uniform_prior = True
self.hard_selection = False
- elif args.method == 'hMoElp':
+ elif cfg.method == "hMoElp":
# hard MoE with learned prior
self.uniform_prior = False
self.hard_selection = True
- elif args.method == 'hMoEup':
+ elif cfg.method == "hMoEup":
# hard MoE with uniform prior
self.uniform_prior = True
self.hard_selection = True
# add indicator tokens for each expert
- for i in range(args.num_experts):
+ for i in range(cfg.num_experts):
# add to both dictionaries in case we're sharing embeddings
- src_dict.add_symbol('
([(Schneider et al., 2019)](https://arxiv.org/abs/1904.05862)) | 32.5M | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_large.pt)
+Description | Dataset | Model
+---|---|---
+Wav2Vec large | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_large.pt)
#### Example usage:
```python
import torch
-from fairseq.models.wav2vec import Wav2VecModel
+import fairseq
-cp = torch.load('/path/to/wav2vec.pt')
-model = Wav2VecModel.build_model(cp['args'], task=None)
-model.load_state_dict(cp['model'])
+cp_path = '/path/to/wav2vec.pt'
+model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path])
+model = model[0]
model.eval()
wav_input_16khz = torch.randn(1,10000)
@@ -25,23 +232,76 @@ c = model.feature_aggregator(z)
## Training a new model with the CLI tools
-Given a directory containing wav files to be used for pretraining (we recommend splitting each file into separate file 10 to 30 seconds in length)
+Given a directory containing wav files to be used for pretraining (we recommend splitting each file into separate files 10 to 30 seconds in length)
### Prepare training data manifest:
```
-$ python scripts/wav2vec_manifest.py /path/to/waves --dest /manifest/path --ext wav
+$ python examples/wav2vec/wav2vec_manifest.py /path/to/waves --dest /manifest/path --ext wav
```
### Train a wav2vec model:
```
$ python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 --max-update 400000 --save-interval 1 --no-epoch-checkpoints \
---arch wav2vec --task audio_pretraining --lr 1e-06 --min-lr 1e-09 --optimizer adam --max-lr 0.005 --lr-scheduler cosine \
+--arch wav2vec --task audio_pretraining --min-lr 1e-06 --stop-min-lr 1e-09 --optimizer adam --lr 0.005 --lr-scheduler cosine \
--conv-feature-layers [(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1)] \
--conv-aggregator-layers [(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)] \
---skip-connections-agg --residual-scale 0.5 --log-compression --warmup-updates 500 --warmup-init-lr 1e-07 --criterion binary_cross_entropy --num-negatives 10 \
---max-sample-size 150000 --max-tokens 1500000 ---skip-invalid-size-inputs-valid-test
+--skip-connections-agg --residual-scale 0.5 --log-compression --warmup-updates 500 --warmup-init-lr 1e-07 --criterion wav2vec --num-negatives 10 \
+--max-sample-size 150000 --max-tokens 1500000 --skip-invalid-size-inputs-valid-test
+```
+
+### Run wav2vec2 pre-training on Google Cloud TPUs:
+
+Wav2Vec2 is now supported on TPUs! It's currently pre-training only.
+
+#### Using hydra on a v3-8:
+
+```
+$ OMP_NUM_THREADS=1 fairseq-hydra-train \
+ task.data=/manifest/path \
+ --config-dir /PATH/TO/FAIRSEQ/examples/wav2vec/config/pretraining \
+ --config-name wav2vec2_large_librivox_tpu.yaml
+```
+
+#### Using command line arguments on a v3-8:
+Note: Commandline arguments way of execution has a [known-problem](https://github.com/pytorch/fairseq/issues/3741) currently.
+
+```
+$ OMP_NUM_THREADS=1 python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 --max-update 400000 --save-interval 1 --no-epoch-checkpoints \
+--arch wav2vec2 --task audio_pretraining --min-lr 1e-06 --stop-min-lr 1e-09 --optimizer adam --lr 0.005 --lr-scheduler cosine \
+--conv-feature-layers [(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1)] \
+--conv-aggregator-layers [(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)] \
+--skip-connections-agg --residual-scale 0.5 --log-compression --warmup-updates 500 --warmup-init-lr 1e-07 --criterion wav2vec --num-negatives 10 \
+--max-sample-size 150000 --max-tokens 1500000 --skip-invalid-size-inputs-valid-test \
+--tpu --distributed-world-size 8 --num-batch-buckets 3 --enable-padding \
+--encoder-layerdrop 0 --mask-channel-prob 0.1
+```
+
+#### Using hydra on a pod slice (v3-N with N > 8):
+
+```
+$ OMP_NUM_THREADS=1 fairseq-hydra-train \
+ task.data=/manifest/path \
+ --config-dir /PATH/TO/FAIRSEQ/examples/wav2vec/config/pretraining \
+ --config-name wav2vec2_large_librivox_tpu-pod.yaml # edit distributed-world-size accordingly
+```
+
+#### Using command line arguments on a pod slice (v3-N with N > 8):
+Note: Commandline arguments way of execution has a [known-problem](https://github.com/pytorch/fairseq/issues/3741) currently.
+
+```
+$ python -m torch_xla.distributed.xla_dist \
+ --tpu ${TPUNAME} --conda-env=torch-xla-${TORCH_XLA_VERSION} --env OMP_NUM_THREADS=1 \
+ -- \
+python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 --max-update 400000 --save-interval 1 --no-epoch-checkpoints \
+--arch wav2vec2 --task audio_pretraining --min-lr 1e-06 --stop-min-lr 1e-09 --optimizer adam --lr 0.005 --lr-scheduler cosine \
+--conv-feature-layers [(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1)] \
+--conv-aggregator-layers [(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)] \
+--skip-connections-agg --residual-scale 0.5 --log-compression --warmup-updates 500 --warmup-init-lr 1e-07 --criterion wav2vec --num-negatives 10 \
+--max-sample-size 150000 --max-tokens 1500000 --skip-invalid-size-inputs-valid-test \
+--tpu --distributed-world-size ${WORLD_SIZE} --num-batch-buckets 3 --enable-padding \
+--encoder-layerdrop 0 --mask-channel-prob 0.1
```
### Extract embeddings from the downstream task data:
@@ -55,22 +315,24 @@ $ PYTHONPATH=/path/to/fairseq python examples/wav2vec/wav2vec_featurize.py --inp
Example to train a vq-wav2vec model as described in [vq-wav2vec: Self-Supervised Learning of Discrete Speech Representations (Baevski et al., 2019)](https://arxiv.org/abs/1910.05453).
+These models are also used in [Effectiveness of self-supervised pre-training for speech recognition (Baevski et al., 2019)](https://arxiv.org/abs/1911.03912).
+
## Pre-trained models
-Description | Parameters | Dataset | Model
----|---:|---|---
-vq-wav2vec Gumbel
([(Baevski et al., 2019)](https://arxiv.org/abs/1910.05453)) | 34.1M | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/vq-wav2vec.pt)
-vq-wav2vec K-means
([(Baevski et al., 2019)](https://arxiv.org/abs/1910.05453)) | 33.0M | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/vq-wav2vec_kmeans.pt)
-Roberta on K-means codes
([(Baevski et al., 2019)](https://arxiv.org/abs/1910.05453)) | 123.6M | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/bert_kmeans.tar)
+Description | Dataset | Model
+---|---|---
+vq-wav2vec Gumbel | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/vq-wav2vec.pt)
+vq-wav2vec K-means | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/vq-wav2vec_kmeans.pt)
+Roberta on K-means codes | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/bert_kmeans.tar)
#### Example usage:
```python
import torch
-from fairseq.models.wav2vec import Wav2VecModel
+import fairseq
cp = torch.load('/path/to/vq-wav2vec.pt')
-model = Wav2VecModel.build_model(cp['args'], task=None)
-model.load_state_dict(cp['model'])
+model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp])
+model = model[0]
model.eval()
wav_input_16khz = torch.randn(1,10000)
@@ -93,14 +355,14 @@ $ python examples/wav2vec/wav2vec_manifest.py /path/to/waves --dest /manifest/pa
```
$ python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 --max-update 400000 \
---save-interval 1 --no-epoch-checkpoints --arch wav2vec --task audio_pretraining --lr 1e-06 --min-lr 1e-09 \
---optimizer adam --max-lr 1e-05 --lr-scheduler cosine \
+--save-interval 1 --no-epoch-checkpoints --arch wav2vec --task audio_pretraining --min-lr 1e-06 --stop-min-lr 1e-09 \
+--optimizer adam --lr 1e-05 --lr-scheduler cosine \
--conv-feature-layers [(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1), (512, 1, 1)] \
--conv-aggregator-layers [(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)] \
--activation gelu --offset auto --skip-connections-agg --residual-scale 0.5 \
--log-keys ["prob_perplexity","code_perplexity","temp"] --vq-type gumbel --vq-groups 2 --vq-depth 2 \
--combine-groups --vq-vars 320 --vq-temp (2,0.5,0.999995) --prediction-steps 12 --warmup-updates 1000 \
---warmup-init-lr 1e-07 --criterion binary_cross_entropy --num-negatives 10 --max-sample-size 150000 \
+--warmup-init-lr 1e-07 --criterion wav2vec --num-negatives 10 --max-sample-size 150000 \
--max-tokens 300000 --cross-sample-negatives 0 --update-freq 1 --seed 2 --skip-invalid-size-inputs-valid-test
```
diff --git a/examples/wav2vec/__init__.py b/examples/wav2vec/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/examples/wav2vec/config/finetuning/base_100h.yaml b/examples/wav2vec/config/finetuning/base_100h.yaml
new file mode 100644
index 0000000000..153b5df170
--- /dev/null
+++ b/examples/wav2vec/config/finetuning/base_100h.yaml
@@ -0,0 +1,58 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+
+checkpoint:
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: wer
+
+task:
+ _name: audio_finetuning
+ data: ???
+ normalize: false
+ labels: ltr
+
+dataset:
+ num_workers: 6
+ max_tokens: 3200000
+ skip_invalid_size_inputs_valid_test: true
+ valid_subset: dev_other
+
+distributed_training:
+ ddp_backend: legacy_ddp
+ distributed_world_size: 2
+
+criterion:
+ _name: ctc
+ zero_infinity: true
+
+optimization:
+ max_update: 80000
+ lr: [0.00003]
+ sentence_avg: true
+ update_freq: [4]
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-08
+
+lr_scheduler:
+ _name: tri_stage
+ phase_ratio: [0.1, 0.4, 0.5]
+ final_lr_scale: 0.05
+
+model:
+ _name: wav2vec_ctc
+ w2v_path: ???
+ apply_mask: true
+ mask_prob: 0.65
+ mask_channel_prob: 0.5
+ mask_channel_length: 64
+ layerdrop: 0.1
+ activation_dropout: 0.1
+ feature_grad_mult: 0.0
+ freeze_finetune_updates: 0
diff --git a/examples/wav2vec/config/finetuning/base_10h.yaml b/examples/wav2vec/config/finetuning/base_10h.yaml
new file mode 100644
index 0000000000..5044518025
--- /dev/null
+++ b/examples/wav2vec/config/finetuning/base_10h.yaml
@@ -0,0 +1,63 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+
+checkpoint:
+ save_interval: 50
+ save_interval_updates: 10000
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: wer
+
+task:
+ _name: audio_finetuning
+ data: ???
+ normalize: false
+ labels: ltr
+
+dataset:
+ num_workers: 6
+ max_tokens: 3200000
+ skip_invalid_size_inputs_valid_test: true
+ validate_after_updates: 10000
+ validate_interval: 50
+ valid_subset: dev_other
+
+distributed_training:
+ ddp_backend: legacy_ddp
+ distributed_world_size: 2
+
+criterion:
+ _name: ctc
+ zero_infinity: true
+
+optimization:
+ max_update: 20000
+ lr: [0.00005]
+ sentence_avg: true
+ update_freq: [4]
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-08
+
+lr_scheduler:
+ _name: tri_stage
+ phase_ratio: [0.1, 0.4, 0.5]
+ final_lr_scale: 0.05
+
+model:
+ _name: wav2vec_ctc
+ w2v_path: ???
+ apply_mask: true
+ mask_prob: 0.65
+ mask_channel_prob: 0.5
+ mask_channel_length: 64
+ layerdrop: 0.05
+ activation_dropout: 0.1
+ feature_grad_mult: 0.0
+ freeze_finetune_updates: 10000
diff --git a/examples/wav2vec/config/finetuning/base_10m.yaml b/examples/wav2vec/config/finetuning/base_10m.yaml
new file mode 100644
index 0000000000..14abc013bd
--- /dev/null
+++ b/examples/wav2vec/config/finetuning/base_10m.yaml
@@ -0,0 +1,63 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+
+checkpoint:
+ save_interval: 1000
+ save_interval_updates: 50
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: wer
+
+task:
+ _name: audio_finetuning
+ data: ???
+ normalize: false
+ labels: ltr
+
+dataset:
+ num_workers: 6
+ max_tokens: 3200000
+ skip_invalid_size_inputs_valid_test: true
+ validate_after_updates: 10000
+ validate_interval: 1000
+ valid_subset: dev_other
+
+distributed_training:
+ ddp_backend: legacy_ddp
+ distributed_world_size: 2
+
+criterion:
+ _name: ctc
+ zero_infinity: true
+
+optimization:
+ max_update: 13000
+ lr: [0.00005]
+ sentence_avg: true
+ update_freq: [4]
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-08
+
+lr_scheduler:
+ _name: tri_stage
+ phase_ratio: [0.1, 0.4, 0.5]
+ final_lr_scale: 0.05
+
+model:
+ _name: wav2vec_ctc
+ w2v_path: ???
+ apply_mask: true
+ mask_prob: 0.65
+ mask_channel_prob: 0.25
+ mask_channel_length: 64
+ layerdrop: 0.1
+ activation_dropout: 0.1
+ feature_grad_mult: 0.0
+ freeze_finetune_updates: 10000
diff --git a/examples/wav2vec/config/finetuning/base_1h.yaml b/examples/wav2vec/config/finetuning/base_1h.yaml
new file mode 100644
index 0000000000..14abc013bd
--- /dev/null
+++ b/examples/wav2vec/config/finetuning/base_1h.yaml
@@ -0,0 +1,63 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+
+checkpoint:
+ save_interval: 1000
+ save_interval_updates: 50
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: wer
+
+task:
+ _name: audio_finetuning
+ data: ???
+ normalize: false
+ labels: ltr
+
+dataset:
+ num_workers: 6
+ max_tokens: 3200000
+ skip_invalid_size_inputs_valid_test: true
+ validate_after_updates: 10000
+ validate_interval: 1000
+ valid_subset: dev_other
+
+distributed_training:
+ ddp_backend: legacy_ddp
+ distributed_world_size: 2
+
+criterion:
+ _name: ctc
+ zero_infinity: true
+
+optimization:
+ max_update: 13000
+ lr: [0.00005]
+ sentence_avg: true
+ update_freq: [4]
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-08
+
+lr_scheduler:
+ _name: tri_stage
+ phase_ratio: [0.1, 0.4, 0.5]
+ final_lr_scale: 0.05
+
+model:
+ _name: wav2vec_ctc
+ w2v_path: ???
+ apply_mask: true
+ mask_prob: 0.65
+ mask_channel_prob: 0.25
+ mask_channel_length: 64
+ layerdrop: 0.1
+ activation_dropout: 0.1
+ feature_grad_mult: 0.0
+ freeze_finetune_updates: 10000
diff --git a/examples/wav2vec/config/finetuning/base_960h.yaml b/examples/wav2vec/config/finetuning/base_960h.yaml
new file mode 100644
index 0000000000..3eadc36b37
--- /dev/null
+++ b/examples/wav2vec/config/finetuning/base_960h.yaml
@@ -0,0 +1,57 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+
+checkpoint:
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: wer
+
+task:
+ _name: audio_finetuning
+ data: ???
+ normalize: false
+ labels: ltr
+
+dataset:
+ num_workers: 6
+ max_tokens: 3200000
+ skip_invalid_size_inputs_valid_test: true
+ valid_subset: dev_other
+
+distributed_training:
+ ddp_backend: legacy_ddp
+ distributed_world_size: 8
+
+criterion:
+ _name: ctc
+ zero_infinity: true
+
+optimization:
+ max_update: 320000
+ lr: [0.0001]
+ sentence_avg: true
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-08
+
+lr_scheduler:
+ _name: tri_stage
+ phase_ratio: [0.1, 0.4, 0.5]
+ final_lr_scale: 0.05
+
+model:
+ _name: wav2vec_ctc
+ w2v_path: ???
+ apply_mask: true
+ mask_prob: 0.5
+ mask_channel_prob: 0.1
+ mask_channel_length: 64
+ layerdrop: 0.1
+ activation_dropout: 0.1
+ feature_grad_mult: 0.0
+ freeze_finetune_updates: 0
diff --git a/examples/wav2vec/config/finetuning/vox_100h.yaml b/examples/wav2vec/config/finetuning/vox_100h.yaml
new file mode 100644
index 0000000000..b8f81e5e18
--- /dev/null
+++ b/examples/wav2vec/config/finetuning/vox_100h.yaml
@@ -0,0 +1,58 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+
+checkpoint:
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: wer
+
+task:
+ _name: audio_finetuning
+ data: ???
+ normalize: true
+ labels: ltr
+
+dataset:
+ num_workers: 6
+ max_tokens: 1280000
+ skip_invalid_size_inputs_valid_test: true
+ valid_subset: dev_other
+
+distributed_training:
+ ddp_backend: legacy_ddp
+ distributed_world_size: 4
+
+criterion:
+ _name: ctc
+ zero_infinity: true
+
+optimization:
+ max_update: 80000
+ lr: [0.00003]
+ sentence_avg: true
+ update_freq: [5]
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-08
+
+lr_scheduler:
+ _name: tri_stage
+ phase_ratio: [0.1, 0.4, 0.5]
+ final_lr_scale: 0.05
+
+model:
+ _name: wav2vec_ctc
+ w2v_path: ???
+ apply_mask: true
+ mask_prob: 0.5
+ mask_channel_prob: 0.5
+ mask_channel_length: 64
+ layerdrop: 0.1
+ activation_dropout: 0.1
+ feature_grad_mult: 0.0
+ freeze_finetune_updates: 10000
diff --git a/examples/wav2vec/config/finetuning/vox_10h.yaml b/examples/wav2vec/config/finetuning/vox_10h.yaml
new file mode 100644
index 0000000000..8f1ca71ee2
--- /dev/null
+++ b/examples/wav2vec/config/finetuning/vox_10h.yaml
@@ -0,0 +1,63 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+
+checkpoint:
+ save_interval: 50
+ save_interval_updates: 10000
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: wer
+
+task:
+ _name: audio_finetuning
+ data: ???
+ normalize: true
+ labels: ltr
+
+dataset:
+ num_workers: 6
+ max_tokens: 1280000
+ skip_invalid_size_inputs_valid_test: true
+ validate_after_updates: 10000
+ validate_interval: 50
+ valid_subset: dev_other
+
+distributed_training:
+ ddp_backend: legacy_ddp
+ distributed_world_size: 4
+
+criterion:
+ _name: ctc
+ zero_infinity: true
+
+optimization:
+ max_update: 20000
+ lr: [0.0001]
+ sentence_avg: true
+ update_freq: [5]
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-08
+
+lr_scheduler:
+ _name: tri_stage
+ phase_ratio: [0.1, 0.4, 0.5]
+ final_lr_scale: 0.05
+
+model:
+ _name: wav2vec_ctc
+ w2v_path: ???
+ apply_mask: true
+ mask_prob: 0.75
+ mask_channel_prob: 0.25
+ mask_channel_length: 64
+ layerdrop: 0.1
+ activation_dropout: 0.1
+ feature_grad_mult: 0.0
+ freeze_finetune_updates: 10000
diff --git a/examples/wav2vec/config/finetuning/vox_10m.yaml b/examples/wav2vec/config/finetuning/vox_10m.yaml
new file mode 100644
index 0000000000..07e327fe74
--- /dev/null
+++ b/examples/wav2vec/config/finetuning/vox_10m.yaml
@@ -0,0 +1,63 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+
+checkpoint:
+ save_interval: 1000
+ save_interval_updates: 50
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: wer
+
+task:
+ _name: audio_finetuning
+ data: ???
+ normalize: true
+ labels: ltr
+
+dataset:
+ num_workers: 6
+ max_tokens: 1280000
+ skip_invalid_size_inputs_valid_test: true
+ validate_after_updates: 10000
+ validate_interval: 1000
+ valid_subset: dev_other
+
+distributed_training:
+ ddp_backend: legacy_ddp
+ distributed_world_size: 4
+
+criterion:
+ _name: ctc
+ zero_infinity: true
+
+optimization:
+ max_update: 13000
+ lr: [0.0001]
+ sentence_avg: true
+ update_freq: [5]
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-08
+
+lr_scheduler:
+ _name: tri_stage
+ phase_ratio: [0.1, 0.4, 0.5]
+ final_lr_scale: 0.05
+
+model:
+ _name: wav2vec_ctc
+ w2v_path: ???
+ apply_mask: true
+ mask_prob: 0.65
+ mask_channel_prob: 0.25
+ mask_channel_length: 64
+ layerdrop: 0.1
+ activation_dropout: 0.1
+ feature_grad_mult: 0.0
+ freeze_finetune_updates: 10000
diff --git a/examples/wav2vec/config/finetuning/vox_1h.yaml b/examples/wav2vec/config/finetuning/vox_1h.yaml
new file mode 100644
index 0000000000..fac1bbb32f
--- /dev/null
+++ b/examples/wav2vec/config/finetuning/vox_1h.yaml
@@ -0,0 +1,63 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+
+checkpoint:
+ save_interval: 1000
+ save_interval_updates: 50
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: wer
+
+task:
+ _name: audio_finetuning
+ data: ???
+ normalize: true
+ labels: ltr
+
+dataset:
+ num_workers: 6
+ max_tokens: 1280000
+ skip_invalid_size_inputs_valid_test: true
+ validate_after_updates: 10000
+ validate_interval: 1000
+ valid_subset: dev_other
+
+distributed_training:
+ ddp_backend: legacy_ddp
+ distributed_world_size: 4
+
+criterion:
+ _name: ctc
+ zero_infinity: true
+
+optimization:
+ max_update: 13000
+ lr: [0.0003]
+ sentence_avg: true
+ update_freq: [5]
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-08
+
+lr_scheduler:
+ _name: tri_stage
+ phase_ratio: [0.1, 0.4, 0.5]
+ final_lr_scale: 0.05
+
+model:
+ _name: wav2vec_ctc
+ w2v_path: ???
+ apply_mask: true
+ mask_prob: 0.75
+ mask_channel_prob: 0.25
+ mask_channel_length: 64
+ layerdrop: 0.1
+ activation_dropout: 0.1
+ feature_grad_mult: 0.0
+ freeze_finetune_updates: 10000
diff --git a/examples/wav2vec/config/finetuning/vox_960h.yaml b/examples/wav2vec/config/finetuning/vox_960h.yaml
new file mode 100644
index 0000000000..9d72404fa3
--- /dev/null
+++ b/examples/wav2vec/config/finetuning/vox_960h.yaml
@@ -0,0 +1,57 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+
+checkpoint:
+ no_epoch_checkpoints: true
+ best_checkpoint_metric: wer
+
+task:
+ _name: audio_finetuning
+ data: ???
+ normalize: true
+ labels: ltr
+
+dataset:
+ num_workers: 6
+ max_tokens: 1280000
+ skip_invalid_size_inputs_valid_test: true
+ valid_subset: dev_other
+
+distributed_training:
+ ddp_backend: legacy_ddp
+ distributed_world_size: 24
+
+criterion:
+ _name: ctc
+ zero_infinity: true
+
+optimization:
+ max_update: 320000
+ lr: [0.00003]
+ sentence_avg: true
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-08
+
+lr_scheduler:
+ _name: tri_stage
+ phase_ratio: [0.1, 0.4, 0.5]
+ final_lr_scale: 0.05
+
+model:
+ _name: wav2vec_ctc
+ w2v_path: ???
+ apply_mask: true
+ mask_prob: 0.5
+ mask_channel_prob: 0.25
+ mask_channel_length: 64
+ layerdrop: 0.1
+ activation_dropout: 0.1
+ feature_grad_mult: 0.0
+ freeze_finetune_updates: 10000
diff --git a/examples/wav2vec/config/pretraining/wav2vec2_base_librispeech.yaml b/examples/wav2vec/config/pretraining/wav2vec2_base_librispeech.yaml
new file mode 100644
index 0000000000..b686e21ab1
--- /dev/null
+++ b/examples/wav2vec/config/pretraining/wav2vec2_base_librispeech.yaml
@@ -0,0 +1,57 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+
+checkpoint:
+ save_interval_updates: 25000
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+
+task:
+ _name: audio_pretraining
+ data: ???
+ max_sample_size: 250000
+ min_sample_size: 32000
+ normalize: false
+
+dataset:
+ num_workers: 6
+ max_tokens: 1400000
+ skip_invalid_size_inputs_valid_test: true
+
+distributed_training:
+ distributed_world_size: 64
+ ddp_backend: legacy_ddp
+
+criterion:
+ _name: wav2vec
+ infonce: true
+ log_keys: ["prob_perplexity","code_perplexity","temp"]
+ loss_weights: [0.1, 10]
+
+optimization:
+ max_update: 400000
+ lr: [0.0005]
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+ weight_decay: 0.01
+
+lr_scheduler:
+ _name: polynomial_decay
+ warmup_updates: 32000
+
+model:
+ _name: wav2vec2
+ quantize_targets: true
+ final_dim: 256
+ encoder_layerdrop: 0.05
+ dropout_input: 0.1
+ dropout_features: 0.1
+ feature_grad_mult: 0.1
+ encoder_embed_dim: 768
diff --git a/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml b/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml
new file mode 100644
index 0000000000..3192ce4cba
--- /dev/null
+++ b/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml
@@ -0,0 +1,70 @@
+# @package _group_
+
+common:
+ fp16: true
+ log_format: json
+ log_interval: 200
+
+checkpoint:
+ save_interval_updates: 25000
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+
+task:
+ _name: audio_pretraining
+ data: ???
+ max_sample_size: 320000
+ min_sample_size: 32000
+ normalize: true
+
+dataset:
+ batch_size: 4
+ num_workers: 6
+ max_tokens: 1200000
+ skip_invalid_size_inputs_valid_test: true
+
+distributed_training:
+ distributed_world_size: 128
+ ddp_backend: legacy_ddp
+
+criterion:
+ _name: wav2vec
+ infonce: true
+ log_keys: ["prob_perplexity","code_perplexity","temp"]
+ loss_weights: [0.1, 0]
+
+optimization:
+ max_update: 1000000
+ lr: [0.005]
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+ weight_decay: 0.01
+
+lr_scheduler:
+ _name: polynomial_decay
+ warmup_updates: 32000
+
+model:
+ _name: wav2vec2
+ quantize_targets: true
+ extractor_mode: layer_norm
+ layer_norm_first: true
+ final_dim: 768
+ latent_temp: [2.0,0.1,0.999995]
+ encoder_layerdrop: 0.00
+ dropout_input: 0.0
+ dropout_features: 0.0
+ dropout: 0.0
+ attention_dropout: 0.0
+ conv_bias: true
+
+ encoder_layers: 24
+ encoder_embed_dim: 1024
+ encoder_ffn_embed_dim: 4096
+ encoder_attention_heads: 16
+
+ feature_grad_mult: 1.0
+
diff --git a/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu-pod.yaml b/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu-pod.yaml
new file mode 100644
index 0000000000..ff35a95b65
--- /dev/null
+++ b/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu-pod.yaml
@@ -0,0 +1,72 @@
+# @package _group_
+
+common:
+ tpu: true
+ fp16: false
+ log_format: json
+ log_interval: 10
+
+checkpoint:
+ save_interval_updates: 25000
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+
+task:
+ _name: audio_pretraining
+ data: ???
+ max_sample_size: 250000
+ min_sample_size: 32000
+ normalize: true
+ num_batch_buckets: 3
+ precompute_mask_indices: true
+ enable_padding: true
+
+dataset:
+ num_workers: 6
+ max_tokens: 1200000
+ skip_invalid_size_inputs_valid_test: true
+
+distributed_training:
+ distributed_world_size: 128
+ ddp_backend: legacy_ddp
+
+criterion:
+ _name: wav2vec
+ infonce: true
+ log_keys: ["prob_perplexity","code_perplexity","temp"]
+ loss_weights: [0.1, 0]
+
+optimization:
+ max_update: 1000000
+ lr: [0.005]
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+ weight_decay: 0.01
+
+lr_scheduler:
+ _name: polynomial_decay
+ warmup_updates: 32000
+
+model:
+ _name: wav2vec2
+ quantize_targets: true
+ extractor_mode: layer_norm
+ layer_norm_first: true
+ final_dim: 768
+ latent_temp: [2.0,0.1,0.999995]
+ encoder_layerdrop: 0.00
+ dropout_input: 0.0
+ dropout_features: 0.0
+ dropout: 0.0
+ attention_dropout: 0.0
+ conv_bias: true
+
+ encoder_layers: 24
+ encoder_embed_dim: 1024
+ encoder_ffn_embed_dim: 4096
+ encoder_attention_heads: 16
+
+ feature_grad_mult: 1.0
diff --git a/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu.yaml b/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu.yaml
new file mode 100644
index 0000000000..ee55bdab72
--- /dev/null
+++ b/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu.yaml
@@ -0,0 +1,77 @@
+# @package _group_
+
+common:
+ tpu: true
+ fp16: false
+ log_format: json
+ log_interval: 10
+
+checkpoint:
+ save_interval_updates: 25000
+ keep_interval_updates: 1
+ no_epoch_checkpoints: true
+
+task:
+ _name: audio_pretraining
+ data: ???
+ max_sample_size: 250000
+ min_sample_size: 32000
+ normalize: true
+ num_batch_buckets: 3
+ precompute_mask_indices: true
+ enable_padding: true
+ inferred_w2v_config:
+ mask_prob: 0.65
+ mask_selection: 'static'
+ mask_other: 0
+ mask_channel_prob: 0.1
+
+dataset:
+ num_workers: 6
+ max_tokens: 1200000
+ skip_invalid_size_inputs_valid_test: true
+
+distributed_training:
+ distributed_world_size: 8
+ ddp_backend: legacy_ddp
+
+criterion:
+ _name: wav2vec
+ infonce: true
+ log_keys: ["prob_perplexity","code_perplexity","temp"]
+ loss_weights: [0.1, 0]
+
+optimization:
+ max_update: 1000000
+ lr: [0.005]
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+ weight_decay: 0.01
+
+lr_scheduler:
+ _name: polynomial_decay
+ warmup_updates: 32000
+
+model:
+ _name: wav2vec2
+ quantize_targets: true
+ extractor_mode: layer_norm
+ layer_norm_first: true
+ final_dim: 768
+ latent_temp: [2.0,0.1,0.999995]
+ encoder_layerdrop: 0.00
+ dropout_input: 0.0
+ dropout_features: 0.0
+ dropout: 0.0
+ attention_dropout: 0.0
+ conv_bias: true
+
+ encoder_layers: 24
+ encoder_embed_dim: 1024
+ encoder_ffn_embed_dim: 4096
+ encoder_attention_heads: 16
+
+ feature_grad_mult: 1.0
diff --git a/examples/wav2vec/libri_labels.py b/examples/wav2vec/libri_labels.py
new file mode 100644
index 0000000000..694a202604
--- /dev/null
+++ b/examples/wav2vec/libri_labels.py
@@ -0,0 +1,56 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Helper script to pre-compute embeddings for a flashlight (previously called wav2letter++) dataset
+"""
+
+import argparse
+import os
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("tsv")
+ parser.add_argument("--output-dir", required=True)
+ parser.add_argument("--output-name", required=True)
+ args = parser.parse_args()
+
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ transcriptions = {}
+
+ with open(args.tsv, "r") as tsv, open(
+ os.path.join(args.output_dir, args.output_name + ".ltr"), "w"
+ ) as ltr_out, open(
+ os.path.join(args.output_dir, args.output_name + ".wrd"), "w"
+ ) as wrd_out:
+ root = next(tsv).strip()
+ for line in tsv:
+ line = line.strip()
+ dir = os.path.dirname(line)
+ if dir not in transcriptions:
+ parts = dir.split(os.path.sep)
+ trans_path = f"{parts[-2]}-{parts[-1]}.trans.txt"
+ path = os.path.join(root, dir, trans_path)
+ assert os.path.exists(path)
+ texts = {}
+ with open(path, "r") as trans_f:
+ for tline in trans_f:
+ items = tline.strip().split()
+ texts[items[0]] = " ".join(items[1:])
+ transcriptions[dir] = texts
+ part = os.path.basename(line).split(".")[0]
+ assert part in transcriptions[dir]
+ print(transcriptions[dir][part], file=wrd_out)
+ print(
+ " ".join(list(transcriptions[dir][part].replace(" ", "|"))) + " |",
+ file=ltr_out,
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/wav2vec/scripts/binarize_manifest.sh b/examples/wav2vec/scripts/binarize_manifest.sh
new file mode 100644
index 0000000000..6f201bdb52
--- /dev/null
+++ b/examples/wav2vec/scripts/binarize_manifest.sh
@@ -0,0 +1,33 @@
+#!/usr/bin/env bash
+
+# usage: bash binarize_manifest